liger-kernel 0.5.2__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/README.md +25 -0
- liger_kernel/chunked_loss/__init__.py +2 -0
- liger_kernel/chunked_loss/cpo_loss.py +18 -8
- liger_kernel/chunked_loss/dpo_loss.py +20 -10
- liger_kernel/chunked_loss/functional.py +4 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +58 -44
- liger_kernel/chunked_loss/fused_linear_preference.py +108 -60
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +246 -0
- liger_kernel/chunked_loss/jsd_loss.py +154 -0
- liger_kernel/chunked_loss/kto_loss.py +172 -0
- liger_kernel/chunked_loss/orpo_loss.py +8 -9
- liger_kernel/chunked_loss/simpo_loss.py +22 -8
- liger_kernel/env_report.py +5 -12
- liger_kernel/ops/cross_entropy.py +102 -51
- liger_kernel/ops/experimental/embedding.py +1 -3
- liger_kernel/ops/experimental/mm_int8int2.py +3 -9
- liger_kernel/ops/fused_linear_cross_entropy.py +89 -55
- liger_kernel/ops/fused_linear_jsd.py +11 -29
- liger_kernel/ops/geglu.py +6 -17
- liger_kernel/ops/group_norm.py +11 -28
- liger_kernel/ops/jsd.py +2 -6
- liger_kernel/ops/kl_div.py +8 -11
- liger_kernel/ops/layer_norm.py +3 -5
- liger_kernel/ops/qwen2vl_mrope.py +8 -25
- liger_kernel/ops/rms_norm.py +14 -32
- liger_kernel/ops/rope.py +31 -33
- liger_kernel/ops/swiglu.py +4 -8
- liger_kernel/ops/utils.py +2 -0
- liger_kernel/transformers/__init__.py +16 -24
- liger_kernel/transformers/auto_model.py +6 -13
- liger_kernel/transformers/cross_entropy.py +4 -6
- liger_kernel/transformers/experimental/embedding.py +1 -3
- liger_kernel/transformers/functional.py +11 -7
- liger_kernel/transformers/fused_linear_cross_entropy.py +12 -7
- liger_kernel/transformers/geglu.py +1 -4
- liger_kernel/transformers/group_norm.py +3 -9
- liger_kernel/transformers/jsd.py +1 -3
- liger_kernel/transformers/kl_div.py +1 -3
- liger_kernel/transformers/layer_norm.py +3 -9
- liger_kernel/transformers/model/gemma.py +18 -40
- liger_kernel/transformers/model/gemma2.py +19 -41
- liger_kernel/transformers/model/llama.py +22 -48
- liger_kernel/transformers/model/mistral.py +14 -26
- liger_kernel/transformers/model/mixtral.py +24 -54
- liger_kernel/transformers/model/mllama.py +16 -36
- liger_kernel/transformers/model/phi3.py +18 -40
- liger_kernel/transformers/model/qwen2.py +18 -40
- liger_kernel/transformers/model/qwen2_vl.py +36 -32
- liger_kernel/transformers/monkey_patch.py +43 -117
- liger_kernel/transformers/rms_norm.py +4 -4
- liger_kernel/transformers/rope.py +2 -2
- liger_kernel/transformers/swiglu.py +2 -8
- liger_kernel/transformers/trainer/__init__.py +1 -3
- liger_kernel/transformers/trainer/orpo_trainer.py +31 -18
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -3
- {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.3.dist-info}/METADATA +38 -25
- liger_kernel-0.5.3.dist-info/RECORD +69 -0
- {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.3.dist-info}/WHEEL +1 -1
- liger_kernel-0.5.2.dist-info/RECORD +0 -65
- {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.3.dist-info}/LICENSE +0 -0
- {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.3.dist-info}/NOTICE +0 -0
- {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.3.dist-info}/top_level.txt +0 -0
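Only the monkey-patched model forwards under liger_kernel/transformers/model/ are rendered in the hunks below. As a rough, unofficial way to reproduce this kind of wheel-to-wheel comparison locally, the sketch below diffs two downloaded wheels using only the standard library; the pip commands and local wheel filenames are assumptions, not something this diff page provides.

# Minimal reproduction sketch (not part of this diff page). Assumes both wheels were
# fetched locally first, e.g. with:
#   pip download liger-kernel==0.5.2 --no-deps
#   pip download liger-kernel==0.5.3 --no-deps
import difflib
import zipfile

OLD = "liger_kernel-0.5.2-py3-none-any.whl"  # assumed local filename
NEW = "liger_kernel-0.5.3-py3-none-any.whl"  # assumed local filename


def read_members(path):
    # A wheel is a plain zip archive; map member name -> decoded text.
    with zipfile.ZipFile(path) as zf:
        return {name: zf.read(name).decode("utf-8", errors="replace") for name in zf.namelist()}


old_files, new_files = read_members(OLD), read_members(NEW)
for name in sorted(set(old_files) | set(new_files)):
    diff = difflib.unified_diff(
        old_files.get(name, "").splitlines(keepends=True),
        new_files.get(name, "").splitlines(keepends=True),
        fromfile=f"0.5.2/{name}",
        tofile=f"0.5.3/{name}",
    )
    text = "".join(diff)
    if text:
        print(text)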
liger_kernel/transformers/model/mistral.py

@@ -1,27 +1,23 @@
-from typing import List, Optional, Tuple, Union
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
+
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.mistral.modeling_mistral import (
-    _CONFIG_FOR_DOC,
-    MISTRAL_INPUTS_DOCSTRING,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
+from transformers.models.mistral.modeling_mistral import _CONFIG_FOR_DOC
+from transformers.models.mistral.modeling_mistral import MISTRAL_INPUTS_DOCSTRING
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
 
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 
 
 @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -65,19 +61,11 @@ def lce_forward(
         "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
         ```"""
 
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(

liger_kernel/transformers/model/mixtral.py

@@ -1,27 +1,23 @@
-from typing import List, Optional, Tuple, Union
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
+
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import MoeCausalLMOutputWithPast
-from transformers.models.mixtral.modeling_mixtral import (
-    _CONFIG_FOR_DOC,
-    MIXTRAL_INPUTS_DOCSTRING,
-    load_balancing_loss_func,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
+from transformers.models.mixtral.modeling_mixtral import _CONFIG_FOR_DOC
+from transformers.models.mixtral.modeling_mixtral import MIXTRAL_INPUTS_DOCSTRING
+from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
 
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 
 
 @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -38,7 +34,7 @@ def lce_forward_deprecated(
     cache_position: Optional[torch.LongTensor] = None,
 ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
     r"""
-    Copy paste Mixtral's forward from
+    Copy paste Mixtral's forward from transformers v4.44.2 but replace torch cross entropy with liger fused linear cross entropy
 
 
     Args:
@@ -66,25 +62,15 @@ def lce_forward_deprecated(
         "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
         ```"""
 
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_router_logits = (
-        output_router_logits
-        if output_router_logits is not None
-        else self.config.output_router_logits
+        output_router_logits if output_router_logits is not None else self.config.output_router_logits
     )
 
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -138,9 +124,7 @@ def lce_forward_deprecated(
             attention_mask,
         )
         if labels is not None:
-            loss += self.router_aux_loss_coef * aux_loss.to(
-                loss.device
-            )  # make sure to reside in the same device
+            loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
 
     if not return_dict:
         output = (logits,) + outputs[1:]
@@ -160,9 +144,7 @@ def lce_forward_deprecated(
 
 
 @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 # Ignore copy
 def lce_forward(
     self,
@@ -212,25 +194,15 @@ def lce_forward(
         "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
         ```"""
 
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_router_logits = (
-        output_router_logits
-        if output_router_logits is not None
-        else self.config.output_router_logits
+        output_router_logits if output_router_logits is not None else self.config.output_router_logits
     )
 
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -288,9 +260,7 @@ def lce_forward(
             attention_mask,
         )
         if labels is not None:
-            loss += self.router_aux_loss_coef * aux_loss.to(
-                loss.device
-            )  # make sure to reside in the same device
+            loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
 
     if not return_dict:
         output = (logits,) + outputs[1:]

liger_kernel/transformers/model/mllama.py

@@ -1,24 +1,22 @@
-from typing import List, Optional, Tuple, Union
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
+
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.models.mllama.modeling_mllama import MLLAMA_INPUTS_DOCSTRING
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
 
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 
 
 @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig"
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig")
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -66,19 +64,11 @@ def lce_forward_deprecated(
         I love the idea of snowflakes gently falling, each one
         ```
     """
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -143,9 +133,7 @@ def lce_forward_deprecated(
 
 
 @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig"
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig")
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -198,19 +186,11 @@ def lce_forward(
         I love the idea of snowflakes gently falling, each one
         ```
     """
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(

liger_kernel/transformers/model/phi3.py

@@ -1,26 +1,22 @@
-from typing import List, Optional, Tuple, Union
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
+
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.phi3.modeling_phi3 import (
-    _CONFIG_FOR_DOC,
-    PHI3_INPUTS_DOCSTRING,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
+from transformers.models.phi3.modeling_phi3 import _CONFIG_FOR_DOC
+from transformers.models.phi3.modeling_phi3 import PHI3_INPUTS_DOCSTRING
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
 
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 
 
 @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -64,19 +60,11 @@ def lce_forward_deprecated(
        'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
         ```"""
 
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -138,9 +126,7 @@ def lce_forward_deprecated(
 
 
 @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -202,19 +188,11 @@ def lce_forward(
             f"If you are not using the generate method, you may encounter nonsensical outputs after the {self.config.original_max_position_embeddings}th token, as the KV cache needs to be recomputed."
         )
 
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(

liger_kernel/transformers/model/qwen2.py

@@ -1,26 +1,22 @@
-from typing import List, Optional, Tuple, Union
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
+
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.qwen2.modeling_qwen2 import (
-    _CONFIG_FOR_DOC,
-    QWEN2_INPUTS_DOCSTRING,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
+from transformers.models.qwen2.modeling_qwen2 import _CONFIG_FOR_DOC
+from transformers.models.qwen2.modeling_qwen2 import QWEN2_INPUTS_DOCSTRING
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
 
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 
 
 @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -63,19 +59,11 @@ def lce_forward_deprecated(
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -137,9 +125,7 @@ def lce_forward_deprecated(
 
 
 @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -187,19 +173,11 @@ def lce_forward(
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
 
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(

liger_kernel/transformers/model/qwen2_vl.py

@@ -1,28 +1,24 @@
-from typing import List, Optional, Tuple, Union
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
+
 from packaging import version
 from torch.nn import CrossEntropyLoss
 from transformers import __version__ as transformers_version
-from transformers.models.qwen2_vl.modeling_qwen2_vl import (
-    _CONFIG_FOR_DOC,
-    QWEN2_VL_INPUTS_DOCSTRING,
-    Qwen2VLCausalLMOutputWithPast,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
+from transformers.models.qwen2_vl.modeling_qwen2_vl import _CONFIG_FOR_DOC
+from transformers.models.qwen2_vl.modeling_qwen2_vl import QWEN2_VL_INPUTS_DOCSTRING
+from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
 
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 
 
 @add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -40,6 +36,7 @@ def lce_forward(
     image_grid_thw: Optional[torch.LongTensor] = None,
     video_grid_thw: Optional[torch.LongTensor] = None,
     rope_deltas: Optional[torch.LongTensor] = None,
+    cache_position: Optional[torch.LongTensor] = None,
 ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
     r"""
     Copy paste Qwen2VL's forward but replace torch cross entropy with liger fused linear cross entropy
@@ -82,19 +79,11 @@ def lce_forward(
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
        ```"""
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     if inputs_embeds is None:
         inputs_embeds = self.model.embed_tokens(input_ids)
@@ -137,16 +126,30 @@ def lce_forward(
         if attention_mask is not None:
             attention_mask = attention_mask.to(inputs_embeds.device)
 
-    if version.parse(transformers_version) > version.parse("4.46.2"):
+    if version.parse(transformers_version) > version.parse("4.46.3"):
         # NOTE: this bug fix for qwen2-vl is not applied until transformers 4.47.0
         # https://github.com/huggingface/transformers/issues/33401
         # While correct, this breaks equivalence with past versions of Qwen2-VL from
         # transformers and leads to failed tests or users noticing differences in results.
        # TODO: remove above conditional when liger drops support for transformers<4.47.0
-        if position_ids is None and input_ids is not None:
-            position_ids, rope_deltas = self.get_rope_index(
-                input_ids, image_grid_thw, video_grid_thw, attention_mask
-            )
+        # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
+        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
+            # calculate RoPE index once per generation in the pre-fill stage only
+            if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
+                position_ids, rope_deltas = self.get_rope_index(
+                    input_ids, image_grid_thw, video_grid_thw, attention_mask
+                )
+                self.rope_deltas = rope_deltas
+            # then use the prev pre-calculated rope-deltas to get the correct position ids
+            else:
+                batch_size, seq_length, _ = inputs_embeds.shape
+                delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
+                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
+                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+                if cache_position is not None:  # otherwise `deltas` is an int `0`
+                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+                position_ids = position_ids.add(delta)
+                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
 
     outputs = self.model(
         input_ids=None,
@@ -158,6 +161,7 @@ def lce_forward(
         output_attentions=output_attentions,
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
+        cache_position=cache_position,
     )
 
     hidden_states = outputs[0]