liger-kernel-nightly 0.5.2.dev20241223032630__py3-none-any.whl → 0.5.2.dev20241228022953__py3-none-any.whl
- liger_kernel/chunked_loss/cpo_loss.py +5 -12
- liger_kernel/chunked_loss/dpo_loss.py +1 -4
- liger_kernel/chunked_loss/fused_linear_distillation.py +37 -37
- liger_kernel/chunked_loss/fused_linear_preference.py +40 -64
- liger_kernel/chunked_loss/orpo_loss.py +2 -6
- liger_kernel/chunked_loss/simpo_loss.py +4 -8
- liger_kernel/env_report.py +4 -11
- liger_kernel/ops/cross_entropy.py +7 -10
- liger_kernel/ops/experimental/embedding.py +1 -3
- liger_kernel/ops/experimental/mm_int8int2.py +3 -9
- liger_kernel/ops/fused_linear_cross_entropy.py +12 -17
- liger_kernel/ops/fused_linear_jsd.py +11 -29
- liger_kernel/ops/geglu.py +6 -17
- liger_kernel/ops/group_norm.py +11 -28
- liger_kernel/ops/jsd.py +2 -6
- liger_kernel/ops/kl_div.py +4 -7
- liger_kernel/ops/layer_norm.py +3 -5
- liger_kernel/ops/qwen2vl_mrope.py +8 -25
- liger_kernel/ops/rms_norm.py +11 -29
- liger_kernel/ops/rope.py +8 -24
- liger_kernel/ops/swiglu.py +4 -8
- liger_kernel/ops/utils.py +2 -0
- liger_kernel/transformers/__init__.py +16 -24
- liger_kernel/transformers/auto_model.py +6 -13
- liger_kernel/transformers/cross_entropy.py +1 -3
- liger_kernel/transformers/experimental/embedding.py +1 -3
- liger_kernel/transformers/functional.py +2 -6
- liger_kernel/transformers/fused_linear_cross_entropy.py +2 -6
- liger_kernel/transformers/geglu.py +1 -4
- liger_kernel/transformers/group_norm.py +3 -9
- liger_kernel/transformers/jsd.py +1 -3
- liger_kernel/transformers/kl_div.py +1 -3
- liger_kernel/transformers/layer_norm.py +3 -9
- liger_kernel/transformers/model/gemma.py +18 -40
- liger_kernel/transformers/model/gemma2.py +19 -41
- liger_kernel/transformers/model/llama.py +22 -48
- liger_kernel/transformers/model/mistral.py +14 -26
- liger_kernel/transformers/model/mixtral.py +23 -53
- liger_kernel/transformers/model/mllama.py +16 -36
- liger_kernel/transformers/model/phi3.py +18 -40
- liger_kernel/transformers/model/qwen2.py +18 -40
- liger_kernel/transformers/model/qwen2_vl.py +16 -30
- liger_kernel/transformers/monkey_patch.py +43 -117
- liger_kernel/transformers/rms_norm.py +4 -4
- liger_kernel/transformers/swiglu.py +2 -8
- liger_kernel/transformers/trainer/__init__.py +1 -3
- liger_kernel/transformers/trainer/orpo_trainer.py +13 -16
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -3
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/METADATA +1 -1
- liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/RECORD +66 -0
- liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/RECORD +0 -66
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/fused_linear_cross_entropy.py CHANGED
@@ -2,9 +2,7 @@ from typing import Optional
 
 import torch
 
-from liger_kernel.ops.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyFunction,
-)
+from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
 
 
 class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
@@ -25,9 +23,7 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
             "sum",
             "none",
         }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
-        assert (
-            softcap is None or softcap > 0
-        ), f"softcap must greater than 0.0 or None. Got: {softcap}"
+        assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}"
         self.ignore_index = ignore_index
         self.lse_square_scale = lse_square_scale
         self.label_smoothing = label_smoothing
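Both hunks above are pure reflow: the constructor still validates `reduction` and `softcap`, and the loss is still computed by `LigerFusedLinearCrossEntropyFunction`. As a rough usage sketch (the tensor sizes, the `reduction` choice, and the CUDA device are illustrative assumptions, not part of this diff):

```python
import torch

from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss

# Hypothetical sizes; the Triton kernels behind this module expect a CUDA device.
hidden_size, vocab_size, num_tokens = 64, 128, 8

lm_head_weight = torch.randn(vocab_size, hidden_size, device="cuda", requires_grad=True)
hidden_states = torch.randn(num_tokens, hidden_size, device="cuda", requires_grad=True)
labels = torch.randint(0, vocab_size, (num_tokens,), device="cuda")

# reduction and softcap are the arguments validated in the hunk above
# (softcap must be None or > 0).
loss_fn = LigerFusedLinearCrossEntropyLoss(reduction="mean", softcap=None)
loss = loss_fn(lm_head_weight, hidden_states, labels)
loss.backward()
```

The `lce(self.lm_head.weight, shift_hidden_states, shift_labels)` call in the gemma2 hunks further down uses the same weight-first convention.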
liger_kernel/transformers/geglu.py CHANGED
@@ -19,7 +19,4 @@ class LigerGEGLUMLP(nn.Module):
         # So we can safely assume we use tanh approximation form all the time
 
     def forward(self, x):
-
-        return self.down_proj(
-            LigerGELUMulFunction.apply(self.gate_proj(x), self.up_proj(x))
-        )
+        return self.down_proj(LigerGELUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
liger_kernel/transformers/group_norm.py CHANGED
@@ -27,19 +27,13 @@ class LigerGroupNorm(nn.Module):
         self.num_channels = num_channels
         self.num_groups = num_groups
         self.eps = eps
-        self.weight = nn.Parameter(
-            torch.ones(num_channels) if init_fn == "ones" else torch.zeros(num_channels)
-        )
-        self.bias = nn.Parameter(
-            torch.randn(num_channels) if bias else torch.zeros(num_channels)
-        )
+        self.weight = nn.Parameter(torch.ones(num_channels) if init_fn == "ones" else torch.zeros(num_channels))
+        self.bias = nn.Parameter(torch.randn(num_channels) if bias else torch.zeros(num_channels))
         self.variance_epsilon = eps
 
     def forward(self, hidden_states):
         # hidden_states: (batch_size, num_channels, *)
-        assert (
-            hidden_states.dim() >= 3
-        ), f"Input must have atleast 3 dimensions, got {hidden_states.dim()}"
+        assert hidden_states.dim() >= 3, f"Input must have atleast 3 dimensions, got {hidden_states.dim()}"
         assert (
             hidden_states.size(1) == self.num_channels
         ), f"Input tensor must have {self.num_channels} channels, got {hidden_states.size(1)}"
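For reference, a minimal sketch of the shape contract asserted in `forward` above; the constructor keywords (`num_channels`, `num_groups`, `eps`) are read off the hunk, while the sizes and the CUDA device are assumptions:

```python
import torch

from liger_kernel.transformers.group_norm import LigerGroupNorm

# forward() asserts hidden_states has at least 3 dims and that dim 1 equals num_channels.
norm = LigerGroupNorm(num_channels=8, num_groups=4, eps=1e-6).cuda()
x = torch.randn(2, 8, 16, device="cuda")  # (batch_size, num_channels, *)
y = norm(x)
assert y.shape == x.shape
```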
liger_kernel/transformers/jsd.py CHANGED
@@ -67,6 +67,4 @@ class LigerJSD(torch.nn.Module):
         log_p: torch.Tensor,
         shift_labels: Optional[torch.LongTensor] = None,
     ):
-        return LigerJSDFunction.apply(
-            log_q, log_p, shift_labels, self.beta, self.ignore_index
-        )
+        return LigerJSDFunction.apply(log_q, log_p, shift_labels, self.beta, self.ignore_index)
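Again only the formatting changed; `LigerJSDFunction.apply` still receives `log_q`, `log_p`, `shift_labels`, `beta`, and `ignore_index`. A hedged usage sketch (shapes, the `beta` value, and the device are illustrative assumptions):

```python
import torch

from liger_kernel.transformers.jsd import LigerJSD

# Both inputs are log-probabilities, matching the log_q/log_p parameter names above;
# shift_labels stays optional and is omitted here.
jsd = LigerJSD(beta=0.5)
log_q = torch.randn(4, 16, device="cuda").log_softmax(dim=-1)
log_p = torch.randn(4, 16, device="cuda").log_softmax(dim=-1)
loss = jsd(log_q, log_p)
```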
liger_kernel/transformers/kl_div.py CHANGED
@@ -9,6 +9,4 @@ class LigerKLDIVLoss(nn.KLDivLoss):
         self.eps = eps
 
     def forward(self, y_pred, y_true):
-        return LigerKLDivLossFunction.apply(
-            y_pred, y_true, self.reduction, self.log_target, self.eps
-        )
+        return LigerKLDivLossFunction.apply(y_pred, y_true, self.reduction, self.log_target, self.eps)
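Since the class subclasses `nn.KLDivLoss` (see the hunk header), the call shape mirrors the stock PyTorch loss. A minimal sketch with default arguments; the input shapes and the CUDA device are assumptions:

```python
import torch

from liger_kernel.transformers.kl_div import LigerKLDIVLoss

# y_pred is log-probabilities and y_true is probabilities unless log_target=True,
# exactly as with torch.nn.KLDivLoss; reduction and log_target are forwarded
# to the fused kernel in the hunk above.
kl = LigerKLDIVLoss()
y_pred = torch.randn(4, 16, device="cuda").log_softmax(dim=-1)
y_true = torch.randn(4, 16, device="cuda").softmax(dim=-1)
loss = kl(y_pred, y_true)
```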
liger_kernel/transformers/layer_norm.py CHANGED
@@ -13,18 +13,12 @@ class LigerLayerNorm(nn.Module):
         ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
         self.hidden_size = hidden_size
         self.eps = eps
-        self.weight = nn.Parameter(
-            torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)
-        )
-        self.bias = nn.Parameter(
-            torch.randn(hidden_size) if bias else torch.zeros(hidden_size)
-        )
+        self.weight = nn.Parameter(torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size))
+        self.bias = nn.Parameter(torch.randn(hidden_size) if bias else torch.zeros(hidden_size))
         self.variance_epsilon = eps
 
     def forward(self, hidden_states):
-        return LigerLayerNormFunction.apply(
-            hidden_states, self.weight, self.bias, self.variance_epsilon
-        )
+        return LigerLayerNormFunction.apply(hidden_states, self.weight, self.bias, self.variance_epsilon)
 
     def extra_repr(self):
         return f"{self.hidden_size}, eps={self.eps}"
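Same reflow pattern as the group-norm hunk above. A short sketch of the reconstructed module; the hidden size, `eps` value, and device are illustrative assumptions:

```python
import torch

from liger_kernel.transformers.layer_norm import LigerLayerNorm

ln = LigerLayerNorm(hidden_size=64, eps=1e-6).cuda()
hidden_states = torch.randn(2, 10, 64, device="cuda")
out = ln(hidden_states)
print(ln)  # extra_repr above renders the hidden size and eps
```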
liger_kernel/transformers/model/gemma.py CHANGED
@@ -1,27 +1,23 @@
-from typing import List, Optional, Tuple, Union
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
+
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.gemma.modeling_gemma import (
-    GEMMA_INPUTS_DOCSTRING,
-    _CONFIG_FOR_DOC,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
+from transformers.models.gemma.modeling_gemma import _CONFIG_FOR_DOC
+from transformers.models.gemma.modeling_gemma import GEMMA_INPUTS_DOCSTRING
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
 
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 
 
 @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -64,19 +60,11 @@ def lce_forward_deprecated(
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
     "What is your favorite condiment?"
     ```"""
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -139,9 +127,7 @@ def lce_forward_deprecated(
 
 
 @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -188,19 +174,11 @@ def lce_forward(
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
     "What is your favorite condiment?"
     ```"""
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
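The patched `lce_forward` / `lce_forward_deprecated` functions above are not called directly; they are swapped into the Hugging Face model classes by the monkey-patch layer (`liger_kernel/transformers/monkey_patch.py` in the file list). A rough sketch, assuming the `apply_liger_kernel_to_gemma` entry point exported by `liger_kernel.transformers` and an arbitrary Gemma checkpoint:

```python
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_gemma

# Replace GemmaForCausalLM.forward with the fused linear + cross-entropy
# variant shown in the hunks above (keyword name assumed).
apply_liger_kernel_to_gemma(fused_linear_cross_entropy=True)

model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
```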
liger_kernel/transformers/model/gemma2.py CHANGED
@@ -1,22 +1,20 @@
 import logging
-from typing import Optional, Tuple, Union
+
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
+
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import HybridCache
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.gemma2.modeling_gemma2 import (
-    GEMMA2_INPUTS_DOCSTRING,
-    _CONFIG_FOR_DOC,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
-
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from transformers.models.gemma2.modeling_gemma2 import _CONFIG_FOR_DOC
+from transformers.models.gemma2.modeling_gemma2 import GEMMA2_INPUTS_DOCSTRING
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
+
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 
 logger = logging.getLogger(__name__)
 
@@ -63,19 +61,11 @@ def lce_forward_deprecated(
            "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
            f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
        )
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
         input_ids=input_ids,
@@ -104,9 +94,7 @@ def lce_forward_deprecated(
         shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
         shift_labels = shift_labels.view(-1)
 
-        lce = LigerFusedLinearCrossEntropyLoss(
-            softcap=self.config.final_logit_softcapping
-        )
+        lce = LigerFusedLinearCrossEntropyLoss(softcap=self.config.final_logit_softcapping)
         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
 
     else:
@@ -146,9 +134,7 @@ def lce_forward_deprecated(
 
 
 @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -201,19 +187,11 @@ def lce_forward(
            "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
            f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
        )
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
         input_ids=input_ids,
liger_kernel/transformers/model/llama.py CHANGED
@@ -1,30 +1,27 @@
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
 import torch.nn.functional as F
+
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.llama.modeling_llama import (
-    LLAMA_INPUTS_DOCSTRING,
-    _CONFIG_FOR_DOC,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
-
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from transformers.models.llama.modeling_llama import _CONFIG_FOR_DOC
+from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
+
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 
 if TYPE_CHECKING:
     from transformers.cache_utils import Cache
 
 
 @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -67,19 +64,11 @@ def lce_forward_deprecated(
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
     "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
     ```"""
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -113,13 +102,8 @@ def lce_forward_deprecated(
 
     else:
         if self.config.pretraining_tp > 1:
-            lm_head_slices = self.lm_head.weight.split(
-                self.vocab_size // self.config.pretraining_tp, dim=0
-            )
-            logits = [
-                F.linear(hidden_states, lm_head_slices[i])
-                for i in range(self.config.pretraining_tp)
-            ]
+            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
+            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
             logits = torch.cat(logits, dim=-1)
         else:
             logits = self.lm_head(hidden_states)
@@ -151,9 +135,7 @@ def lce_forward_deprecated(
 
 
 @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -201,19 +183,11 @@ def lce_forward(
     "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
     ```"""
 
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
liger_kernel/transformers/model/mistral.py CHANGED
@@ -1,27 +1,23 @@
-from typing import List, Optional, Tuple, Union
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
+
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.mistral.modeling_mistral import (
-    MISTRAL_INPUTS_DOCSTRING,
-    _CONFIG_FOR_DOC,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
+from transformers.models.mistral.modeling_mistral import _CONFIG_FOR_DOC
+from transformers.models.mistral.modeling_mistral import MISTRAL_INPUTS_DOCSTRING
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
 
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 
 
 @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -65,19 +61,11 @@ def lce_forward(
     "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
     ```"""
 
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
liger_kernel/transformers/model/mixtral.py CHANGED
@@ -1,27 +1,23 @@
-from typing import List, Optional, Tuple, Union
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
+
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import MoeCausalLMOutputWithPast
-from transformers.models.mixtral.modeling_mixtral import (
-    MIXTRAL_INPUTS_DOCSTRING,
-    _CONFIG_FOR_DOC,
-    load_balancing_loss_func,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
+from transformers.models.mixtral.modeling_mixtral import _CONFIG_FOR_DOC
+from transformers.models.mixtral.modeling_mixtral import MIXTRAL_INPUTS_DOCSTRING
+from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
 
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 
 
 @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -66,25 +62,15 @@ def lce_forward_deprecated(
     "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
     ```"""
 
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_router_logits = (
-        output_router_logits
-        if output_router_logits is not None
-        else self.config.output_router_logits
+        output_router_logits if output_router_logits is not None else self.config.output_router_logits
     )
 
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -138,9 +124,7 @@ def lce_forward_deprecated(
             attention_mask,
         )
         if labels is not None:
-            loss += self.router_aux_loss_coef * aux_loss.to(
-                loss.device
-            )  # make sure to reside in the same device
+            loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
 
     if not return_dict:
         output = (logits,) + outputs[1:]
@@ -160,9 +144,7 @@ def lce_forward_deprecated(
 
 
 @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 # Ignore copy
 def lce_forward(
     self,
@@ -212,25 +194,15 @@ def lce_forward(
     "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
     ```"""
 
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_router_logits = (
-        output_router_logits
-        if output_router_logits is not None
-        else self.config.output_router_logits
+        output_router_logits if output_router_logits is not None else self.config.output_router_logits
     )
 
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -288,9 +260,7 @@ def lce_forward(
             attention_mask,
         )
        if labels is not None:
-            loss += self.router_aux_loss_coef * aux_loss.to(
-                loss.device
-            )  # make sure to reside in the same device
+            loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
 
     if not return_dict:
         output = (logits,) + outputs[1:]