liger-kernel 0.1.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/env_report.py +46 -0
- liger_kernel/ops/cross_entropy.py +130 -63
- liger_kernel/ops/experimental/embedding.py +143 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +203 -126
- liger_kernel/ops/geglu.py +56 -44
- liger_kernel/ops/kl_div.py +258 -0
- liger_kernel/ops/layer_norm.py +236 -0
- liger_kernel/ops/rms_norm.py +220 -84
- liger_kernel/ops/rope.py +91 -84
- liger_kernel/ops/swiglu.py +50 -43
- liger_kernel/ops/utils.py +12 -0
- liger_kernel/transformers/__init__.py +22 -0
- liger_kernel/transformers/auto_model.py +45 -0
- liger_kernel/transformers/cross_entropy.py +11 -1
- liger_kernel/transformers/experimental/embedding.py +28 -0
- liger_kernel/transformers/functional.py +19 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +8 -2
- liger_kernel/transformers/geglu.py +4 -2
- liger_kernel/transformers/kl_div.py +14 -0
- liger_kernel/transformers/layer_norm.py +30 -0
- liger_kernel/transformers/model/gemma.py +138 -0
- liger_kernel/transformers/model/llama.py +1 -1
- liger_kernel/transformers/model/mistral.py +138 -0
- liger_kernel/transformers/model/mixtral.py +158 -0
- liger_kernel/transformers/model/phi3.py +136 -0
- liger_kernel/transformers/model/qwen2.py +135 -0
- liger_kernel/transformers/model/qwen2_vl.py +172 -0
- liger_kernel/transformers/monkey_patch.py +579 -14
- liger_kernel/transformers/rms_norm.py +23 -4
- liger_kernel/transformers/swiglu.py +24 -0
- liger_kernel/transformers/trainer_integration.py +2 -45
- liger_kernel-0.3.1.dist-info/METADATA +395 -0
- liger_kernel-0.3.1.dist-info/RECORD +42 -0
- {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.1.dist-info}/WHEEL +1 -1
- liger_kernel-0.1.0.dist-info/METADATA +0 -16
- liger_kernel-0.1.0.dist-info/RECORD +0 -27
- {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.1.dist-info}/LICENSE +0 -0
- {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.1.dist-info}/NOTICE +0 -0
- {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.1.dist-info}/top_level.txt +0 -0
liger_kernel/ops/swiglu.py
CHANGED
|
@@ -14,7 +14,7 @@ def silu(x):
|
|
|
14
14
|
def _swiglu_forward_kernel(
|
|
15
15
|
a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
|
|
16
16
|
):
|
|
17
|
-
program_id = tl.program_id(0)
|
|
17
|
+
program_id = tl.program_id(0).cast(tl.int64)
|
|
18
18
|
|
|
19
19
|
# locate start index
|
|
20
20
|
a_ptr += program_id * stride
|
|
@@ -35,7 +35,7 @@ def _swiglu_forward_kernel(
|
|
|
35
35
|
def _swiglu_backward_kernel(
|
|
36
36
|
dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
|
|
37
37
|
):
|
|
38
|
-
program_id = tl.program_id(0)
|
|
38
|
+
program_id = tl.program_id(0).cast(tl.int64)
|
|
39
39
|
|
|
40
40
|
# locate start index
|
|
41
41
|
dc_ptr += program_id * stride
|
|
@@ -60,54 +60,61 @@ def _swiglu_backward_kernel(
|
|
|
60
60
|
tl.store(b_ptr + col_offsets, db_row, mask=mask)
|
|
61
61
|
|
|
62
62
|
|
|
63
|
+
def swiglu_forward(a, b):
    """Launch the fused SiLU-multiply (SwiGLU) forward kernel.

    Both inputs are viewed as 2-D ``(rows, n_cols)`` tensors over their last
    dimension and one Triton program is launched per row.

    NOTE(review): ``.view`` requires contiguous inputs; the autograd wrapper
    in this module applies ``ensure_contiguous`` before calling here.

    Returns:
        ``(a_2d, b_2d, c)``: the flattened views of ``a`` and ``b`` (kept by
        the caller for the backward pass) and the kernel output reshaped to
        ``a``'s original shape.
    """
    original_shape = a.shape
    width = original_shape[-1]

    a_2d = a.view(-1, width)
    b_2d = b.view(-1, width)
    out = torch.empty_like(a_2d)
    block_size, warps = calculate_settings(width)

    # Grid of one program per row; the row stride comes from the freshly
    # allocated output buffer.
    _swiglu_forward_kernel[(a_2d.shape[0],)](
        a_2d,
        b_2d,
        out,
        out.stride(-2),
        n_cols=width,
        BLOCK_SIZE=block_size,
        num_warps=warps,
    )
    return a_2d, b_2d, out.view(*original_shape)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def swiglu_backward(a, b, dc):
    """Launch the fused SwiGLU backward kernel.

    Args:
        a, b: the 2-D tensors saved by ``swiglu_forward``. The kernel stores
            its results into these buffers (the visible tail of the kernel
            writes ``db_row`` into ``b``), so they are overwritten in place.
        dc: upstream gradient of the forward output; flattened over its last
            dimension for the launch.

    Returns:
        The (now gradient-holding) ``a`` and ``b`` reshaped to ``dc``'s shape.
    """
    grad_shape = dc.shape
    width = grad_shape[-1]
    grad_2d = dc.view(-1, width)
    block_size, warps = calculate_settings(width)

    # One program per row of the upstream gradient.
    _swiglu_backward_kernel[(grad_2d.shape[0],)](
        grad_2d,
        a,
        b,
        grad_2d.stride(-2),
        n_cols=width,
        BLOCK_SIZE=block_size,
        num_warps=warps,
    )
    return a.view(*grad_shape), b.view(*grad_shape)
|
|
105
|
+
|
|
106
|
+
|
|
63
107
|
class LigerSiLUMulFunction(torch.autograd.Function):
    """Autograd wrapper around the fused SwiGLU kernels of this module.

    ``forward`` delegates to ``swiglu_forward`` and saves the flattened
    inputs. ``backward`` hands those saved tensors to ``swiglu_backward``,
    which writes the gradients into them and returns them shaped like ``dc``.

    NOTE(review): because the saved tensors double as gradient buffers,
    running backward twice over the same graph (``retain_graph=True``) is
    likely to produce wrong gradients; confirm before relying on it.
    """

    @staticmethod
    @ensure_contiguous
    def forward(ctx, a, b):
        a_flat, b_flat, product = swiglu_forward(a, b)
        ctx.save_for_backward(a_flat, b_flat)
        return product

    @staticmethod
    @ensure_contiguous
    def backward(ctx, dc):
        saved_a, saved_b = ctx.saved_tensors
        grad_a, grad_b = swiglu_backward(saved_a, saved_b, dc)
        return grad_a, grad_b
|
liger_kernel/ops/utils.py
CHANGED
|
@@ -1,3 +1,15 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This file incorporates code from Unsloth licensed under the Apache License, Version 2.0.
|
|
3
|
+
See the original Unsloth repository at https://github.com/unslothai/unsloth.
|
|
4
|
+
|
|
5
|
+
The following line
|
|
6
|
+
https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/utils.py#L23
|
|
7
|
+
is based on code from Unsloth, located at:
|
|
8
|
+
https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43
|
|
9
|
+
|
|
10
|
+
Modifications made by Yanning Chen, 2024.
|
|
11
|
+
"""
|
|
12
|
+
|
|
1
13
|
import functools
|
|
2
14
|
import importlib
|
|
3
15
|
from typing import Callable
|
|
@@ -1,6 +1,28 @@
|
|
|
1
|
+
from liger_kernel.transformers.auto_model import ( # noqa: F401
|
|
2
|
+
AutoLigerKernelForCausalLM,
|
|
3
|
+
)
|
|
4
|
+
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss # noqa: F401
|
|
5
|
+
from liger_kernel.transformers.fused_linear_cross_entropy import ( # noqa: F401
|
|
6
|
+
LigerFusedLinearCrossEntropyLoss,
|
|
7
|
+
)
|
|
8
|
+
from liger_kernel.transformers.geglu import LigerGEGLUMLP # noqa: F401
|
|
9
|
+
from liger_kernel.transformers.layer_norm import LigerLayerNorm # noqa: F401
|
|
1
10
|
from liger_kernel.transformers.monkey_patch import ( # noqa: F401
|
|
11
|
+
_apply_liger_kernel,
|
|
12
|
+
_apply_liger_kernel_to_instance,
|
|
2
13
|
apply_liger_kernel_to_gemma,
|
|
14
|
+
apply_liger_kernel_to_gemma2,
|
|
3
15
|
apply_liger_kernel_to_llama,
|
|
4
16
|
apply_liger_kernel_to_mistral,
|
|
5
17
|
apply_liger_kernel_to_mixtral,
|
|
18
|
+
apply_liger_kernel_to_phi3,
|
|
19
|
+
apply_liger_kernel_to_qwen2,
|
|
20
|
+
apply_liger_kernel_to_qwen2_vl,
|
|
21
|
+
)
|
|
22
|
+
from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401
|
|
23
|
+
from liger_kernel.transformers.rope import liger_rotary_pos_emb # noqa: F401
|
|
24
|
+
from liger_kernel.transformers.swiglu import ( # noqa: F401
|
|
25
|
+
LigerBlockSparseTop2MLP,
|
|
26
|
+
LigerPhi3SwiGLUMLP,
|
|
27
|
+
LigerSwiGLUMLP,
|
|
6
28
|
)
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
|
|
3
|
+
from transformers import AutoConfig, AutoModelForCausalLM
|
|
4
|
+
|
|
5
|
+
from liger_kernel.transformers.monkey_patch import (
|
|
6
|
+
MODEL_TYPE_TO_APPLY_LIGER_FN,
|
|
7
|
+
_apply_liger_kernel,
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _get_model_config(model_dir, **model_init_kwargs):
    """Return the ``AutoConfig`` loaded from ``model_dir``.

    All keyword arguments (including any Liger-specific ones) are forwarded
    unchanged to ``AutoConfig.from_pretrained``.
    """
    return AutoConfig.from_pretrained(model_dir, **model_init_kwargs)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class AutoLigerKernelForCausalLM(AutoModelForCausalLM):
    """
    This class is a drop-in replacement for AutoModelForCausalLM that applies the Liger Kernel to the model
    if applicable.
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load a causal LM, patching it with Liger kernels first when the model type is supported.

        Keyword arguments accepted by the matching ``apply_liger_kernel_to_*``
        function are consumed here and NOT forwarded to
        ``AutoModelForCausalLM.from_pretrained``. For model types with no entry
        in ``MODEL_TYPE_TO_APPLY_LIGER_FN`` every kwarg is forwarded unchanged.
        """
        model_config = _get_model_config(pretrained_model_name_or_path, **kwargs)

        # Determine the model type and apply the Liger Kernel if applicable
        # Note: _apply_liger_kernel will only pass relevant kwargs to the apply_liger_kernel_to_* function
        # NOTE(review): assumes _apply_liger_kernel tolerates unsupported model types; confirm.
        model_type = model_config.model_type

        _apply_liger_kernel(model_type, **kwargs)

        # Filter out kwargs that were passed to the apply_liger_* function, which will cause
        # model initialization errors otherwise. Unsupported model types have no apply
        # function, so nothing is filtered for them (indexing the mapping would raise KeyError).
        apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(model_type)
        if apply_fn is None:
            model_kwargs = kwargs
        else:
            liger_params = inspect.signature(apply_fn).parameters
            model_kwargs = {
                key: value for key, value in kwargs.items() if key not in liger_params
            }

        return super().from_pretrained(
            pretrained_model_name_or_path, *model_args, **model_kwargs
        )
|
|
@@ -6,6 +6,16 @@ from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
|
|
|
6
6
|
class LigerCrossEntropyLoss(CrossEntropyLoss):
    """``CrossEntropyLoss`` whose forward runs ``LigerCrossEntropyFunction``.

    Construction accepts the parent's arguments and then validates
    ``label_smoothing`` (must lie in [0, 1]) and ``reduction`` (one of
    ``"mean"``, ``"sum"``, ``"none"``) with assertions.

    NOTE(review): only ``ignore_index``, ``label_smoothing`` and ``reduction``
    reach the kernel; a ``weight`` passed to the constructor is not used by
    ``forward``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        smoothing = self.label_smoothing
        assert 0 <= smoothing <= 1, (
            f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}"
        )
        allowed_reductions = {"mean", "sum", "none"}
        assert self.reduction in allowed_reductions, (
            f"reduction must be one of 'mean', 'sum', or 'none'. Got: {self.reduction}"
        )

    def forward(self, _input, target):
        kernel_options = (self.ignore_index, self.label_smoothing, self.reduction)
        return LigerCrossEntropyFunction.apply(_input, target, *kernel_options)
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
|
|
6
|
+
from liger_kernel.ops.experimental.embedding import LigerEmbeddingFunction
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class LigerEmbedding(nn.Module):
    """Embedding lookup backed by ``LigerEmbeddingFunction``.

    Args:
        num_embeddings: number of rows in the embedding table.
        embedding_dim: length of each embedding vector.
        padding_idx: optional row that is zeroed at initialization and whose
            lookups are forced to zero in ``forward``. Negative values count
            from the end of the table, as in ``torch.nn.Embedding``.

    Raises:
        IndexError: if ``padding_idx`` is outside ``[-num_embeddings, num_embeddings)``.
    """

    def __init__(
        self, num_embeddings, embedding_dim, padding_idx: Optional[int] = None
    ):
        super().__init__()
        if padding_idx is not None:
            if not -num_embeddings <= padding_idx < num_embeddings:
                raise IndexError(
                    f"padding_idx must be within [-{num_embeddings}, {num_embeddings}), "
                    f"got {padding_idx}"
                )
            # forward() compares token ids (presumably non-negative) against
            # padding_idx, so a negative index must be normalized or it would
            # never match and padding would not be masked.
            if padding_idx < 0:
                padding_idx += num_embeddings
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.weight = nn.Parameter(torch.randn(num_embeddings, embedding_dim))

        if padding_idx is not None:
            with torch.no_grad():
                self.weight[padding_idx].fill_(0)

    def forward(self, indices):
        embedded = LigerEmbeddingFunction.apply(self.weight, indices)
        if self.padding_idx is not None:
            # Clone first, presumably so the custom Function's output is not
            # modified in place by the masked assignment.
            embedded = embedded.clone()
            embedded[indices == self.padding_idx] = 0
        return embedded
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
|
|
2
|
+
from liger_kernel.ops.fused_linear_cross_entropy import (
|
|
3
|
+
LigerFusedLinearCrossEntropyFunction,
|
|
4
|
+
)
|
|
5
|
+
from liger_kernel.ops.geglu import LigerGELUMulFunction
|
|
6
|
+
from liger_kernel.ops.kl_div import LigerKLDivLossFunction
|
|
7
|
+
from liger_kernel.ops.layer_norm import LigerLayerNormFunction
|
|
8
|
+
from liger_kernel.ops.rms_norm import LigerRMSNormFunction
|
|
9
|
+
from liger_kernel.ops.rope import LigerRopeFunction
|
|
10
|
+
from liger_kernel.ops.swiglu import LigerSiLUMulFunction
|
|
11
|
+
|
|
12
|
+
# Functional (non-module) entry points: each alias is the ``.apply`` of the
# corresponding Liger autograd Function imported by this module.
liger_cross_entropy = LigerCrossEntropyFunction.apply
liger_fused_linear_cross_entropy = LigerFusedLinearCrossEntropyFunction.apply
liger_geglu = LigerGELUMulFunction.apply
liger_kl_div = LigerKLDivLossFunction.apply
liger_layer_norm = LigerLayerNormFunction.apply
liger_rms_norm = LigerRMSNormFunction.apply
liger_rope = LigerRopeFunction.apply
liger_swiglu = LigerSiLUMulFunction.apply
|
|
@@ -9,7 +9,13 @@ class LigerFusedLinearCrossEntropyLoss(CrossEntropyLoss):
|
|
|
9
9
|
def __init__(self, *args, **kwargs):
    """Pass every argument through to ``CrossEntropyLoss.__init__``.

    ``ignore_index``, ``label_smoothing`` and ``reduction`` configured here
    are read later by ``forward``.
    """
    super().__init__(*args, **kwargs)
|
|
11
11
|
|
|
12
|
-
def forward(self, lin_weight, _input, target, bias=None):
    """Compute the fused linear + cross-entropy loss.

    Note the argument order: callers pass the projection weight first (the
    model forwards in this package pass ``lm_head.weight``), while the
    underlying Function takes ``_input`` first.

    Args:
        lin_weight: projection weight.
        _input: 2-D hidden states (the callers flatten them before calling).
        target: flattened labels.
        bias: optional projection bias; ``None`` means no bias.
    """
    loss_options = (self.ignore_index, self.label_smoothing, self.reduction)
    return LigerFusedLinearCrossEntropyFunction.apply(
        _input, lin_weight, target, bias, *loss_options
    )
|
|
@@ -13,8 +13,10 @@ class LigerGEGLUMLP(nn.Module):
|
|
|
13
13
|
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
|
14
14
|
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
|
15
15
|
# TODO: support exact GELU
|
|
16
|
-
|
|
17
|
-
|
|
16
|
+
# Right now Gemma 1, 1.1 and 2 models are all using `gelu_pytorch_tanh`
|
|
17
|
+
# https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/gemma/modeling_gemma.py#L175
|
|
18
|
+
# https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/activations.py#L46
|
|
19
|
+
# So we can safely assume we use tanh approximation form all the time
|
|
18
20
|
|
|
19
21
|
def forward(self, x):
|
|
20
22
|
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
import torch.nn as nn
|
|
2
|
+
|
|
3
|
+
from liger_kernel.ops.kl_div import LigerKLDivLossFunction
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class LigerKLDIVLoss(nn.KLDivLoss):
    """``nn.KLDivLoss`` variant whose forward runs ``LigerKLDivLossFunction``.

    Args:
        eps: constant forwarded to the Liger kernel (its exact role lives in
            ``liger_kernel.ops.kl_div``).
        *args, **kwargs: forwarded to ``nn.KLDivLoss`` (e.g. ``reduction``,
            ``log_target``).

    NOTE(review): ``eps`` is the first positional parameter, so positional
    ``KLDivLoss`` arguments are shifted by one; prefer keyword arguments.
    """

    def __init__(self, eps: float = 1e-10, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.eps = eps

    def forward(self, y_pred, y_true):
        reduction, log_target, eps = self.reduction, self.log_target, self.eps
        return LigerKLDivLossFunction.apply(y_pred, y_true, reduction, log_target, eps)
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
from liger_kernel.ops.layer_norm import LigerLayerNormFunction
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class LigerLayerNorm(nn.Module):
    """Layer-norm module backed by ``LigerLayerNormFunction``.

    Args:
        hidden_size: length of the ``weight`` and ``bias`` vectors.
        eps: epsilon handed to the kernel (stored as both ``eps`` and
            ``variance_epsilon``).
        bias: if True, ``self.bias`` starts from ``torch.randn``; otherwise
            from zeros. In both cases it is a trainable ``nn.Parameter``.
        init_fn: ``"ones"`` or ``"zeros"``, the initial value of ``weight``.
    """

    def __init__(self, hidden_size, eps=1e-6, bias=False, init_fn="ones"):
        super().__init__()
        assert init_fn in [
            "ones",
            "zeros",
        ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
        self.hidden_size = hidden_size
        self.eps = eps
        self.variance_epsilon = eps

        weight_factory = torch.ones if init_fn == "ones" else torch.zeros
        self.weight = nn.Parameter(weight_factory(hidden_size))

        bias_factory = torch.randn if bias else torch.zeros
        self.bias = nn.Parameter(bias_factory(hidden_size))

    def forward(self, hidden_states):
        eps = self.variance_epsilon
        return LigerLayerNormFunction.apply(hidden_states, self.weight, self.bias, eps)

    def extra_repr(self):
        return f"{self.hidden_size}, eps={self.eps}"
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
from typing import List, Optional, Tuple, Union
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch.nn import CrossEntropyLoss
|
|
5
|
+
from transformers.cache_utils import Cache
|
|
6
|
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
7
|
+
from transformers.models.gemma.modeling_gemma import (
|
|
8
|
+
_CONFIG_FOR_DOC,
|
|
9
|
+
GEMMA_INPUTS_DOCSTRING,
|
|
10
|
+
)
|
|
11
|
+
from transformers.utils import (
|
|
12
|
+
add_start_docstrings_to_model_forward,
|
|
13
|
+
replace_return_docstrings,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
from liger_kernel.transformers.fused_linear_cross_entropy import (
|
|
17
|
+
LigerFusedLinearCrossEntropyLoss,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""

    Copy of the transformers.models.gemma.modeling_gemma causal-LM forward with the loss replaced by
    Liger fused linear cross entropy.

    In training mode with `labels` given, the loss is computed from the shifted hidden states and
    `lm_head.weight` by `LigerFusedLinearCrossEntropyLoss`, and the returned `logits` is `None`.
    Otherwise logits are computed with `lm_head` and, when `labels` are given, the loss uses the
    standard `CrossEntropyLoss` on float32 logits.

    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, GemmaForCausalLM

    >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
    >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")

    >>> prompt = "What is your favorite condiment?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "What is your favorite condiment?"
    ```"""
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]

    loss = None
    logits = None

    if self.training and (labels is not None):
        # Shift so that tokens < n predict n
        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # flatten
        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
        # Put labels on the hidden states' device, mirroring the eval branch below;
        # they can differ, e.g. presumably when the model is spread across devices.
        shift_labels = shift_labels.view(-1).to(shift_hidden_states.device)

        lce = LigerFusedLinearCrossEntropyLoss()
        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)

    else:
        logits = self.lm_head(hidden_states)
        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Ensure tensors are on the same device
            shift_labels = shift_labels.to(shift_logits.device)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
from typing import List, Optional, Tuple, Union
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch.nn import CrossEntropyLoss
|
|
5
|
+
from transformers.cache_utils import Cache
|
|
6
|
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
7
|
+
from transformers.models.mistral.modeling_mistral import (
|
|
8
|
+
_CONFIG_FOR_DOC,
|
|
9
|
+
MISTRAL_INPUTS_DOCSTRING,
|
|
10
|
+
)
|
|
11
|
+
from transformers.utils import (
|
|
12
|
+
add_start_docstrings_to_model_forward,
|
|
13
|
+
replace_return_docstrings,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
from liger_kernel.transformers.fused_linear_cross_entropy import (
|
|
17
|
+
LigerFusedLinearCrossEntropyLoss,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Copy of Mistral's causal-LM forward with the torch cross entropy replaced by Liger fused linear
    cross entropy.

    In training mode with `labels` given, the loss is computed from the shifted hidden states and
    `lm_head.weight` by `LigerFusedLinearCrossEntropyLoss`, and the returned `logits` is `None`.
    Otherwise logits are computed with `lm_head` and, when `labels` are given, the loss uses the
    standard `CrossEntropyLoss` on float32 logits.

    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, MistralForCausalLM

    >>> model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
    >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""

    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]

    loss = None
    logits = None

    if self.training and (labels is not None):
        # Shift so that tokens < n predict n
        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # flatten tokens
        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
        # Put labels on the hidden states' device, mirroring the eval branch below;
        # they can differ, e.g. presumably when the model is spread across devices.
        shift_labels = shift_labels.view(-1).to(shift_hidden_states.device)

        lce = LigerFusedLinearCrossEntropyLoss()
        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)

    else:
        logits = self.lm_head(hidden_states)
        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Ensure tensors are on the same device
            shift_labels = shift_labels.to(shift_logits.device)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
|