liger-kernel-nightly 0.5.2.dev20241223032015__py3-none-any.whl → 0.5.2.dev20241223042135__py3-none-any.whl
- liger_kernel/chunked_loss/cpo_loss.py +5 -11
- liger_kernel/chunked_loss/dpo_loss.py +1 -4
- liger_kernel/chunked_loss/fused_linear_distillation.py +37 -37
- liger_kernel/chunked_loss/fused_linear_preference.py +40 -64
- liger_kernel/chunked_loss/orpo_loss.py +2 -6
- liger_kernel/chunked_loss/simpo_loss.py +4 -8
- liger_kernel/env_report.py +4 -11
- liger_kernel/ops/cross_entropy.py +7 -10
- liger_kernel/ops/experimental/embedding.py +1 -3
- liger_kernel/ops/experimental/mm_int8int2.py +3 -9
- liger_kernel/ops/fused_linear_cross_entropy.py +7 -15
- liger_kernel/ops/fused_linear_jsd.py +11 -29
- liger_kernel/ops/geglu.py +6 -17
- liger_kernel/ops/group_norm.py +11 -28
- liger_kernel/ops/jsd.py +2 -6
- liger_kernel/ops/kl_div.py +4 -7
- liger_kernel/ops/layer_norm.py +3 -5
- liger_kernel/ops/qwen2vl_mrope.py +8 -25
- liger_kernel/ops/rms_norm.py +11 -29
- liger_kernel/ops/rope.py +31 -33
- liger_kernel/ops/swiglu.py +4 -8
- liger_kernel/ops/utils.py +2 -0
- liger_kernel/transformers/__init__.py +16 -24
- liger_kernel/transformers/auto_model.py +6 -13
- liger_kernel/transformers/cross_entropy.py +1 -3
- liger_kernel/transformers/experimental/embedding.py +1 -3
- liger_kernel/transformers/functional.py +2 -6
- liger_kernel/transformers/fused_linear_cross_entropy.py +2 -6
- liger_kernel/transformers/geglu.py +1 -4
- liger_kernel/transformers/group_norm.py +3 -9
- liger_kernel/transformers/jsd.py +1 -3
- liger_kernel/transformers/kl_div.py +1 -3
- liger_kernel/transformers/layer_norm.py +3 -9
- liger_kernel/transformers/model/gemma.py +18 -40
- liger_kernel/transformers/model/gemma2.py +19 -41
- liger_kernel/transformers/model/llama.py +22 -48
- liger_kernel/transformers/model/mistral.py +14 -26
- liger_kernel/transformers/model/mixtral.py +23 -53
- liger_kernel/transformers/model/mllama.py +16 -36
- liger_kernel/transformers/model/phi3.py +18 -40
- liger_kernel/transformers/model/qwen2.py +18 -40
- liger_kernel/transformers/model/qwen2_vl.py +16 -30
- liger_kernel/transformers/monkey_patch.py +43 -117
- liger_kernel/transformers/rms_norm.py +4 -4
- liger_kernel/transformers/rope.py +2 -2
- liger_kernel/transformers/swiglu.py +2 -8
- liger_kernel/transformers/trainer/__init__.py +1 -3
- liger_kernel/transformers/trainer/orpo_trainer.py +13 -16
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -3
- {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/METADATA +1 -1
- liger_kernel_nightly-0.5.2.dev20241223042135.dist-info/RECORD +66 -0
- liger_kernel_nightly-0.5.2.dev20241223032015.dist-info/RECORD +0 -66
- {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/model/mllama.py

@@ -1,24 +1,22 @@
-from typing import List, Optional, Tuple, Union
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
+
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.models.mllama.modeling_mllama import MLLAMA_INPUTS_DOCSTRING
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
 
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 
 
 @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig"
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig")
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -66,19 +64,11 @@ def lce_forward_deprecated(
     I love the idea of snowflakes gently falling, each one
     ```
     """
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -143,9 +133,7 @@ def lce_forward_deprecated(
 
 
 @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig"
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig")
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -198,19 +186,11 @@ def lce_forward(
     I love the idea of snowflakes gently falling, each one
     ```
     """
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
liger_kernel/transformers/model/phi3.py

@@ -1,26 +1,22 @@
-from typing import List, Optional, Tuple, Union
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
+
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.phi3.modeling_phi3 import (
-    _CONFIG_FOR_DOC,
-    PHI3_INPUTS_DOCSTRING,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
+from transformers.models.phi3.modeling_phi3 import _CONFIG_FOR_DOC
+from transformers.models.phi3.modeling_phi3 import PHI3_INPUTS_DOCSTRING
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
 
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 
 
 @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -64,19 +60,11 @@ def lce_forward_deprecated(
     'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
     ```"""
 
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -138,9 +126,7 @@ def lce_forward_deprecated(
 
 
 @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -202,19 +188,11 @@ def lce_forward(
             f"If you are not using the generate method, you may encounter nonsensical outputs after the {self.config.original_max_position_embeddings}th token, as the KV cache needs to be recomputed."
         )
 
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
liger_kernel/transformers/model/qwen2.py

@@ -1,26 +1,22 @@
-from typing import List, Optional, Tuple, Union
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
+
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.qwen2.modeling_qwen2 import (
-    _CONFIG_FOR_DOC,
-    QWEN2_INPUTS_DOCSTRING,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
+from transformers.models.qwen2.modeling_qwen2 import _CONFIG_FOR_DOC
+from transformers.models.qwen2.modeling_qwen2 import QWEN2_INPUTS_DOCSTRING
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
 
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 
 
 @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -63,19 +59,11 @@ def lce_forward_deprecated(
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
     "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
     ```"""
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -137,9 +125,7 @@ def lce_forward_deprecated(
 
 
 @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -187,19 +173,11 @@ def lce_forward(
     "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
     ```"""
 
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
liger_kernel/transformers/model/qwen2_vl.py

@@ -1,28 +1,24 @@
-from typing import List, Optional, Tuple, Union
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
+
 from packaging import version
 from torch.nn import CrossEntropyLoss
 from transformers import __version__ as transformers_version
-from transformers.models.qwen2_vl.modeling_qwen2_vl import (
-    _CONFIG_FOR_DOC,
-    QWEN2_VL_INPUTS_DOCSTRING,
-    Qwen2VLCausalLMOutputWithPast,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
+from transformers.models.qwen2_vl.modeling_qwen2_vl import _CONFIG_FOR_DOC
+from transformers.models.qwen2_vl.modeling_qwen2_vl import QWEN2_VL_INPUTS_DOCSTRING
+from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
 
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 
 
 @add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -82,19 +78,11 @@ def lce_forward(
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
     "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
     ```"""
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     if inputs_embeds is None:
         inputs_embeds = self.model.embed_tokens(input_ids)
@@ -144,9 +132,7 @@ def lce_forward(
         # transformers and leads to failed tests or users noticing differences in results.
         # TODO: remove above conditional when liger drops support for transformers<4.47.0
         if position_ids is None and input_ids is not None:
-            position_ids, _ = self.get_rope_index(
-                input_ids, image_grid_thw, video_grid_thw, attention_mask
-            )
+            position_ids, _ = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)
 
     outputs = self.model(
         input_ids=None,