liger-kernel 0.5.2__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/README.md +25 -0
- liger_kernel/chunked_loss/__init__.py +2 -0
- liger_kernel/chunked_loss/cpo_loss.py +18 -8
- liger_kernel/chunked_loss/dpo_loss.py +20 -10
- liger_kernel/chunked_loss/functional.py +4 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +58 -44
- liger_kernel/chunked_loss/fused_linear_preference.py +108 -60
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +246 -0
- liger_kernel/chunked_loss/jsd_loss.py +154 -0
- liger_kernel/chunked_loss/kto_loss.py +172 -0
- liger_kernel/chunked_loss/orpo_loss.py +8 -9
- liger_kernel/chunked_loss/simpo_loss.py +22 -8
- liger_kernel/env_report.py +5 -12
- liger_kernel/ops/cross_entropy.py +102 -51
- liger_kernel/ops/experimental/embedding.py +1 -3
- liger_kernel/ops/experimental/mm_int8int2.py +3 -9
- liger_kernel/ops/fused_linear_cross_entropy.py +89 -55
- liger_kernel/ops/fused_linear_jsd.py +11 -29
- liger_kernel/ops/geglu.py +6 -17
- liger_kernel/ops/group_norm.py +11 -28
- liger_kernel/ops/jsd.py +2 -6
- liger_kernel/ops/kl_div.py +8 -11
- liger_kernel/ops/layer_norm.py +3 -5
- liger_kernel/ops/qwen2vl_mrope.py +8 -25
- liger_kernel/ops/rms_norm.py +14 -32
- liger_kernel/ops/rope.py +31 -33
- liger_kernel/ops/swiglu.py +4 -8
- liger_kernel/ops/utils.py +2 -0
- liger_kernel/transformers/__init__.py +16 -24
- liger_kernel/transformers/auto_model.py +6 -13
- liger_kernel/transformers/cross_entropy.py +4 -6
- liger_kernel/transformers/experimental/embedding.py +1 -3
- liger_kernel/transformers/functional.py +11 -7
- liger_kernel/transformers/fused_linear_cross_entropy.py +12 -7
- liger_kernel/transformers/geglu.py +1 -4
- liger_kernel/transformers/group_norm.py +3 -9
- liger_kernel/transformers/jsd.py +1 -3
- liger_kernel/transformers/kl_div.py +1 -3
- liger_kernel/transformers/layer_norm.py +3 -9
- liger_kernel/transformers/model/gemma.py +18 -40
- liger_kernel/transformers/model/gemma2.py +19 -41
- liger_kernel/transformers/model/llama.py +22 -48
- liger_kernel/transformers/model/mistral.py +14 -26
- liger_kernel/transformers/model/mixtral.py +24 -54
- liger_kernel/transformers/model/mllama.py +16 -36
- liger_kernel/transformers/model/phi3.py +18 -40
- liger_kernel/transformers/model/qwen2.py +18 -40
- liger_kernel/transformers/model/qwen2_vl.py +36 -32
- liger_kernel/transformers/monkey_patch.py +43 -117
- liger_kernel/transformers/rms_norm.py +4 -4
- liger_kernel/transformers/rope.py +2 -2
- liger_kernel/transformers/swiglu.py +2 -8
- liger_kernel/transformers/trainer/__init__.py +1 -3
- liger_kernel/transformers/trainer/orpo_trainer.py +31 -18
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -3
- {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.3.dist-info}/METADATA +38 -25
- liger_kernel-0.5.3.dist-info/RECORD +69 -0
- {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.3.dist-info}/WHEEL +1 -1
- liger_kernel-0.5.2.dist-info/RECORD +0 -65
- {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.3.dist-info}/LICENSE +0 -0
- {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.3.dist-info}/NOTICE +0 -0
- {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.3.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/__init__.py CHANGED
@@ -1,31 +1,23 @@
-from liger_kernel.transformers.auto_model import (  # noqa: F401
-    AutoLigerKernelForCausalLM,
-)
+from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM  # noqa: F401
 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss  # noqa: F401
-from liger_kernel.transformers.fused_linear_cross_entropy import (  # noqa: F401
-    LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss  # noqa: F401
 from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD  # noqa: F401
 from liger_kernel.transformers.geglu import LigerGEGLUMLP  # noqa: F401
 from liger_kernel.transformers.jsd import LigerJSD  # noqa: F401
 from liger_kernel.transformers.layer_norm import LigerLayerNorm  # noqa: F401
-from liger_kernel.transformers.monkey_patch import (  # noqa: F401
-    _apply_liger_kernel,
-    _apply_liger_kernel_to_instance,
-    apply_liger_kernel_to_gemma,
-    apply_liger_kernel_to_gemma2,
-    apply_liger_kernel_to_llama,
-    apply_liger_kernel_to_mistral,
-    apply_liger_kernel_to_mixtral,
-    apply_liger_kernel_to_mllama,
-    apply_liger_kernel_to_phi3,
-    apply_liger_kernel_to_qwen2,
-    apply_liger_kernel_to_qwen2_vl,
-)
+from liger_kernel.transformers.monkey_patch import _apply_liger_kernel  # noqa: F401
+from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl  # noqa: F401
 from liger_kernel.transformers.rms_norm import LigerRMSNorm  # noqa: F401
 from liger_kernel.transformers.rope import liger_rotary_pos_emb  # noqa: F401
-from liger_kernel.transformers.swiglu import (  # noqa: F401
-    LigerBlockSparseTop2MLP,
-    LigerPhi3SwiGLUMLP,
-    LigerSwiGLUMLP,
-)
+from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP  # noqa: F401
+from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP  # noqa: F401
+from liger_kernel.transformers.swiglu import LigerSwiGLUMLP  # noqa: F401
liger_kernel/transformers/auto_model.py CHANGED
@@ -1,11 +1,10 @@
 import inspect
 
-from transformers import AutoConfig, AutoModelForCausalLM
+from transformers import AutoConfig
+from transformers import AutoModelForCausalLM
 
-from liger_kernel.transformers.monkey_patch import (
-    MODEL_TYPE_TO_APPLY_LIGER_FN,
-    _apply_liger_kernel,
-)
+from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
+from liger_kernel.transformers.monkey_patch import _apply_liger_kernel
 
 
 def _get_model_config(model_dir, **model_init_kwargs):
@@ -34,12 +33,6 @@ class AutoLigerKernelForCausalLM(AutoModelForCausalLM):
         apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
         apply_fn_signature = inspect.signature(apply_fn)
 
-        applicable_kwargs = {
-            key: value
-            for key, value in kwargs.items()
-            if key not in apply_fn_signature.parameters
-        }
+        applicable_kwargs = {key: value for key, value in kwargs.items() if key not in apply_fn_signature.parameters}
 
-        return super().from_pretrained(
-            pretrained_model_name_or_path, *model_args, **applicable_kwargs
-        )
+        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **applicable_kwargs)
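These are formatting-only changes; for orientation, AutoLigerKernelForCausalLM stays a drop-in replacement for AutoModelForCausalLM that applies the matching Liger patch before loading weights. A minimal sketch (the checkpoint id is a placeholder, not taken from the diff):

from liger_kernel.transformers import AutoLigerKernelForCausalLM

# Placeholder checkpoint id; any causal LM with a registered apply_liger_kernel_to_* patch works.
model = AutoLigerKernelForCausalLM.from_pretrained("your-org/your-causal-lm")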
liger_kernel/transformers/cross_entropy.py CHANGED
@@ -8,6 +8,7 @@ from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
 class LigerCrossEntropyLoss(torch.nn.Module):
     def __init__(
         self,
+        weight: Optional[torch.FloatTensor] = None,
         ignore_index: int = -100,
         lse_square_scale: float = 0.0,
         label_smoothing: float = 0.0,
@@ -19,17 +20,13 @@ class LigerCrossEntropyLoss(torch.nn.Module):
         assert (label_smoothing >= 0) and (
             label_smoothing <= 1
         ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
-        assert (label_smoothing >= 0) and (
-            label_smoothing <= 1
-        ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
         assert reduction in {
             "mean",
             "sum",
             "none",
         }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
-        assert (
-            softcap is None or softcap > 0
-        ), f"softcap must greater than 0.0 or None. Got: {softcap}"
+        assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}"
+        self.weight = weight
         self.ignore_index = ignore_index
         self.lse_square_scale = lse_square_scale
         self.label_smoothing = label_smoothing
@@ -41,6 +38,7 @@ class LigerCrossEntropyLoss(torch.nn.Module):
         loss, z_loss = LigerCrossEntropyFunction.apply(
             _input,
             target,
+            self.weight,
             self.ignore_index,
             self.lse_square_scale,
             self.label_smoothing,
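The hunks above add an optional per-class weight tensor to LigerCrossEntropyLoss, analogous to torch.nn.CrossEntropyLoss(weight=...). A minimal usage sketch based only on the signature visible here (shapes and values are illustrative; a CUDA device is assumed because the underlying kernels are Triton-based):

import torch

from liger_kernel.transformers import LigerCrossEntropyLoss

device = "cuda"  # assumption: Liger's Triton kernels need a GPU
vocab_size, num_tokens = 8, 4  # illustrative sizes

# Per-class rescaling weights, same meaning as torch.nn.CrossEntropyLoss(weight=...)
class_weights = torch.rand(vocab_size, device=device)

loss_fn = LigerCrossEntropyLoss(weight=class_weights, ignore_index=-100)
logits = torch.randn(num_tokens, vocab_size, device=device, requires_grad=True)
targets = torch.randint(0, vocab_size, (num_tokens,), device=device)

loss = loss_fn(logits, targets)
loss.backward()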
liger_kernel/transformers/experimental/embedding.py CHANGED
@@ -7,9 +7,7 @@ from liger_kernel.ops.experimental.embedding import LigerEmbeddingFunction
 
 
 class LigerEmbedding(nn.Module):
-    def __init__(
-        self, num_embeddings, embedding_dim, padding_idx: Optional[int] = None
-    ):
+    def __init__(self, num_embeddings, embedding_dim, padding_idx: Optional[int] = None):
         super().__init__()
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
liger_kernel/transformers/functional.py CHANGED
@@ -1,9 +1,7 @@
 from typing import Optional
 
 from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
-from liger_kernel.ops.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyFunction,
-)
+from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
 from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
 from liger_kernel.ops.geglu import LigerGELUMulFunction
 from liger_kernel.ops.group_norm import LigerGroupNormFunction
@@ -34,6 +32,7 @@ def liger_cross_entropy(
     loss, z_loss = LigerCrossEntropyFunction.apply(
         input,
         target,
+        weight,
         ignore_index,
         lse_square_scale,
         label_smoothing,
@@ -51,23 +50,30 @@ def liger_fused_linear_cross_entropy(
     weight,
     target,
     bias=None,
+    ce_weight=None,
     ignore_index: int = -100,
     lse_square_scale: float = 0.0,
     label_smoothing: float = 0.0,
     reduction: str = "mean",
     softcap: Optional[float] = None,
+    return_z_loss: bool = False,
 ):
-    return LigerFusedLinearCrossEntropyFunction.apply(
+    loss, z_loss = LigerFusedLinearCrossEntropyFunction.apply(
         input,
         weight,
         target,
         bias,
+        ce_weight,
         ignore_index,
         lse_square_scale,
         label_smoothing,
         reduction,
         softcap,
+        return_z_loss,
     )
+    if not return_z_loss:
+        return loss
+    return loss, z_loss
 
 
 def liger_fused_linear_jsd(
@@ -159,9 +165,7 @@ def liger_qwen2vl_mrope(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
     return LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section, unsqueeze_dim)
 
 
-def liger_rms_norm(
-    X, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True
-):
+def liger_rms_norm(X, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True):
     return LigerRMSNormFunction.apply(X, W, eps, offset, casting_mode, in_place)
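Based on the signature shown above, the functional liger_fused_linear_cross_entropy now accepts ce_weight and return_z_loss. A hedged sketch of how the updated functional API would be called (sizes are illustrative; a CUDA device is assumed since the kernels are Triton-based):

import torch

from liger_kernel.transformers.functional import liger_fused_linear_cross_entropy

device = "cuda"  # assumption: Triton kernels require a GPU
num_tokens, hidden, vocab = 4, 64, 32  # illustrative sizes

x = torch.randn(num_tokens, hidden, device=device, requires_grad=True)
lm_head_weight = torch.randn(vocab, hidden, device=device, requires_grad=True)
target = torch.randint(0, vocab, (num_tokens,), device=device)

# With return_z_loss=True the function returns (loss, z_loss);
# with the default False it still returns the loss alone, as before.
loss, z_loss = liger_fused_linear_cross_entropy(
    x,
    lm_head_weight,
    target,
    ce_weight=torch.ones(vocab, device=device),  # per-class CE weights (new argument)
    lse_square_scale=1e-4,  # non-zero scale so the z-loss term actually contributes
    return_z_loss=True,
)
loss.backward()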
liger_kernel/transformers/fused_linear_cross_entropy.py CHANGED
@@ -2,19 +2,19 @@ from typing import Optional
 
 import torch
 
-from liger_kernel.ops.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyFunction,
-)
+from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
 
 
 class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
     def __init__(
         self,
+        ce_weight: Optional[torch.FloatTensor] = None,
         ignore_index: int = -100,
         lse_square_scale: float = 0.0,
         label_smoothing: float = 0.0,
         reduction: str = "mean",
         softcap: Optional[float] = None,
+        return_z_loss: bool = False,
     ):
         super().__init__()
         assert (label_smoothing >= 0) and (
@@ -25,24 +25,29 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
             "sum",
             "none",
         }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
-        assert (
-            softcap is None or softcap > 0
-        ), f"softcap must greater than 0.0 or None. Got: {softcap}"
+        assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}"
+        self.ce_weight = ce_weight
         self.ignore_index = ignore_index
         self.lse_square_scale = lse_square_scale
         self.label_smoothing = label_smoothing
         self.reduction = reduction
         self.softcap = softcap
+        self.return_z_loss = return_z_loss
 
     def forward(self, lin_weight, _input, target, bias=None):
-        return LigerFusedLinearCrossEntropyFunction.apply(
+        loss, z_loss = LigerFusedLinearCrossEntropyFunction.apply(
             _input,
             lin_weight,
             target,
             bias,
+            self.ce_weight,
             self.ignore_index,
             self.lse_square_scale,
             self.label_smoothing,
             self.reduction,
             self.softcap,
+            self.return_z_loss,
         )
+        if not self.return_z_loss:
+            return loss
+        return loss, z_loss
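The module-level wrapper gains the same ce_weight and return_z_loss options. A usage sketch inferred from this hunk (illustrative sizes; a CUDA device is assumed because the kernels are Triton-based):

import torch

from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

device = "cuda"  # assumption: Triton kernels require a GPU
num_tokens, hidden, vocab = 4, 64, 32  # illustrative sizes

lce = LigerFusedLinearCrossEntropyLoss(
    ce_weight=torch.ones(vocab, device=device),  # optional per-class weights
    lse_square_scale=1e-4,
    return_z_loss=True,
)
hidden_states = torch.randn(num_tokens, hidden, device=device, requires_grad=True)
lm_head = torch.nn.Linear(hidden, vocab, bias=False, device=device)
labels = torch.randint(0, vocab, (num_tokens,), device=device)

# forward(lin_weight, _input, target, bias=None) -> (loss, z_loss) when return_z_loss=True
loss, z_loss = lce(lm_head.weight, hidden_states, labels)
loss.backward()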
liger_kernel/transformers/geglu.py CHANGED
@@ -19,7 +19,4 @@ class LigerGEGLUMLP(nn.Module):
         # So we can safely assume we use tanh approximation form all the time
 
     def forward(self, x):
-
-        return self.down_proj(
-            LigerGELUMulFunction.apply(self.gate_proj(x), self.up_proj(x))
-        )
+        return self.down_proj(LigerGELUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
liger_kernel/transformers/group_norm.py CHANGED
@@ -27,19 +27,13 @@ class LigerGroupNorm(nn.Module):
         self.num_channels = num_channels
         self.num_groups = num_groups
         self.eps = eps
-        self.weight = nn.Parameter(
-            torch.ones(num_channels) if init_fn == "ones" else torch.zeros(num_channels)
-        )
-        self.bias = nn.Parameter(
-            torch.randn(num_channels) if bias else torch.zeros(num_channels)
-        )
+        self.weight = nn.Parameter(torch.ones(num_channels) if init_fn == "ones" else torch.zeros(num_channels))
+        self.bias = nn.Parameter(torch.randn(num_channels) if bias else torch.zeros(num_channels))
         self.variance_epsilon = eps
 
     def forward(self, hidden_states):
         # hidden_states: (batch_size, num_channels, *)
-        assert (
-            hidden_states.dim() >= 3
-        ), f"Input must have atleast 3 dimensions, got {hidden_states.dim()}"
+        assert hidden_states.dim() >= 3, f"Input must have atleast 3 dimensions, got {hidden_states.dim()}"
         assert (
             hidden_states.size(1) == self.num_channels
         ), f"Input tensor must have {self.num_channels} channels, got {hidden_states.size(1)}"
liger_kernel/transformers/jsd.py CHANGED
@@ -67,6 +67,4 @@ class LigerJSD(torch.nn.Module):
         log_p: torch.Tensor,
         shift_labels: Optional[torch.LongTensor] = None,
     ):
-        return LigerJSDFunction.apply(
-            log_q, log_p, shift_labels, self.beta, self.ignore_index
-        )
+        return LigerJSDFunction.apply(log_q, log_p, shift_labels, self.beta, self.ignore_index)
liger_kernel/transformers/kl_div.py CHANGED
@@ -9,6 +9,4 @@ class LigerKLDIVLoss(nn.KLDivLoss):
         self.eps = eps
 
     def forward(self, y_pred, y_true):
-        return LigerKLDivLossFunction.apply(
-            y_pred, y_true, self.reduction, self.log_target, self.eps
-        )
+        return LigerKLDivLossFunction.apply(y_pred, y_true, self.reduction, self.log_target, self.eps)
liger_kernel/transformers/layer_norm.py CHANGED
@@ -13,18 +13,12 @@ class LigerLayerNorm(nn.Module):
         ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
         self.hidden_size = hidden_size
         self.eps = eps
-        self.weight = nn.Parameter(
-            torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)
-        )
-        self.bias = nn.Parameter(
-            torch.randn(hidden_size) if bias else torch.zeros(hidden_size)
-        )
+        self.weight = nn.Parameter(torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size))
+        self.bias = nn.Parameter(torch.randn(hidden_size) if bias else torch.zeros(hidden_size))
         self.variance_epsilon = eps
 
     def forward(self, hidden_states):
-        return LigerLayerNormFunction.apply(
-            hidden_states, self.weight, self.bias, self.variance_epsilon
-        )
+        return LigerLayerNormFunction.apply(hidden_states, self.weight, self.bias, self.variance_epsilon)
 
     def extra_repr(self):
         return f"{self.hidden_size}, eps={self.eps}"
liger_kernel/transformers/model/gemma.py CHANGED
@@ -1,27 +1,23 @@
-from typing import List, Optional, Tuple, Union
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
+
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.gemma.modeling_gemma import (
-    _CONFIG_FOR_DOC,
-    GEMMA_INPUTS_DOCSTRING,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
+from transformers.models.gemma.modeling_gemma import _CONFIG_FOR_DOC
+from transformers.models.gemma.modeling_gemma import GEMMA_INPUTS_DOCSTRING
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
 
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 
 
 @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -64,19 +60,11 @@ def lce_forward_deprecated(
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
     "What is your favorite condiment?"
     ```"""
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -139,9 +127,7 @@ def lce_forward_deprecated(
 
 
 @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -188,19 +174,11 @@ def lce_forward(
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "What is your favorite condiment?"
     ```"""
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
liger_kernel/transformers/model/gemma2.py CHANGED
@@ -1,22 +1,20 @@
 import logging
-from typing import Optional, Tuple, Union
+
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
+
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import HybridCache
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.gemma2.modeling_gemma2 import (
-    _CONFIG_FOR_DOC,
-    GEMMA2_INPUTS_DOCSTRING,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
-
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from transformers.models.gemma2.modeling_gemma2 import _CONFIG_FOR_DOC
+from transformers.models.gemma2.modeling_gemma2 import GEMMA2_INPUTS_DOCSTRING
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
+
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 
 logger = logging.getLogger(__name__)
 
@@ -63,19 +61,11 @@ def lce_forward_deprecated(
             "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
             f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
         )
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
         input_ids=input_ids,
@@ -104,9 +94,7 @@ def lce_forward_deprecated(
         shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
         shift_labels = shift_labels.view(-1)
 
-        lce = LigerFusedLinearCrossEntropyLoss(
-            softcap=self.config.final_logit_softcapping
-        )
+        lce = LigerFusedLinearCrossEntropyLoss(softcap=self.config.final_logit_softcapping)
         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
 
     else:
@@ -146,9 +134,7 @@ def lce_forward_deprecated(
 
 
 @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -201,19 +187,11 @@ def lce_forward(
             "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
             f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
         )
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
         input_ids=input_ids,
liger_kernel/transformers/model/llama.py CHANGED
@@ -1,30 +1,27 @@
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
 import torch.nn.functional as F
+
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.llama.modeling_llama import (
-    _CONFIG_FOR_DOC,
-    LLAMA_INPUTS_DOCSTRING,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
-
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from transformers.models.llama.modeling_llama import _CONFIG_FOR_DOC
+from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
+
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 
 if TYPE_CHECKING:
     from transformers.cache_utils import Cache
 
 
 @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -67,19 +64,11 @@ def lce_forward_deprecated(
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
     "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
     ```"""
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -113,13 +102,8 @@ def lce_forward_deprecated(
 
     else:
         if self.config.pretraining_tp > 1:
-            lm_head_slices = self.lm_head.weight.split(
-                self.vocab_size // self.config.pretraining_tp, dim=0
-            )
-            logits = [
-                F.linear(hidden_states, lm_head_slices[i])
-                for i in range(self.config.pretraining_tp)
-            ]
+            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
+            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
             logits = torch.cat(logits, dim=-1)
         else:
             logits = self.lm_head(hidden_states)
@@ -151,9 +135,7 @@ def lce_forward_deprecated(
 
 
 @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -201,19 +183,11 @@ def lce_forward(
     "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
     ```"""
 
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(