liger-kernel-nightly 0.5.2.dev20241223032630__py3-none-any.whl → 0.5.2.dev20241223042135__py3-none-any.whl
- liger_kernel/chunked_loss/cpo_loss.py +5 -11
- liger_kernel/chunked_loss/dpo_loss.py +1 -4
- liger_kernel/chunked_loss/fused_linear_distillation.py +37 -37
- liger_kernel/chunked_loss/fused_linear_preference.py +40 -64
- liger_kernel/chunked_loss/orpo_loss.py +2 -6
- liger_kernel/chunked_loss/simpo_loss.py +4 -8
- liger_kernel/env_report.py +4 -11
- liger_kernel/ops/cross_entropy.py +7 -10
- liger_kernel/ops/experimental/embedding.py +1 -3
- liger_kernel/ops/experimental/mm_int8int2.py +3 -9
- liger_kernel/ops/fused_linear_cross_entropy.py +7 -15
- liger_kernel/ops/fused_linear_jsd.py +11 -29
- liger_kernel/ops/geglu.py +6 -17
- liger_kernel/ops/group_norm.py +11 -28
- liger_kernel/ops/jsd.py +2 -6
- liger_kernel/ops/kl_div.py +4 -7
- liger_kernel/ops/layer_norm.py +3 -5
- liger_kernel/ops/qwen2vl_mrope.py +8 -25
- liger_kernel/ops/rms_norm.py +11 -29
- liger_kernel/ops/rope.py +8 -24
- liger_kernel/ops/swiglu.py +4 -8
- liger_kernel/ops/utils.py +2 -0
- liger_kernel/transformers/__init__.py +16 -24
- liger_kernel/transformers/auto_model.py +6 -13
- liger_kernel/transformers/cross_entropy.py +1 -3
- liger_kernel/transformers/experimental/embedding.py +1 -3
- liger_kernel/transformers/functional.py +2 -6
- liger_kernel/transformers/fused_linear_cross_entropy.py +2 -6
- liger_kernel/transformers/geglu.py +1 -4
- liger_kernel/transformers/group_norm.py +3 -9
- liger_kernel/transformers/jsd.py +1 -3
- liger_kernel/transformers/kl_div.py +1 -3
- liger_kernel/transformers/layer_norm.py +3 -9
- liger_kernel/transformers/model/gemma.py +18 -40
- liger_kernel/transformers/model/gemma2.py +19 -41
- liger_kernel/transformers/model/llama.py +22 -48
- liger_kernel/transformers/model/mistral.py +14 -26
- liger_kernel/transformers/model/mixtral.py +23 -53
- liger_kernel/transformers/model/mllama.py +16 -36
- liger_kernel/transformers/model/phi3.py +18 -40
- liger_kernel/transformers/model/qwen2.py +18 -40
- liger_kernel/transformers/model/qwen2_vl.py +16 -30
- liger_kernel/transformers/monkey_patch.py +43 -117
- liger_kernel/transformers/rms_norm.py +4 -4
- liger_kernel/transformers/swiglu.py +2 -8
- liger_kernel/transformers/trainer/__init__.py +1 -3
- liger_kernel/transformers/trainer/orpo_trainer.py +13 -16
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -3
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/METADATA +1 -1
- liger_kernel_nightly-0.5.2.dev20241223042135.dist-info/RECORD +66 -0
- liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/RECORD +0 -66
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/top_level.txt +0 -0

liger_kernel/transformers/model/phi3.py

@@ -1,26 +1,22 @@
-from typing import List, Optional, Tuple, Union
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
+
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.phi3.modeling_phi3 import (
-    _CONFIG_FOR_DOC,
-    PHI3_INPUTS_DOCSTRING,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
+from transformers.models.phi3.modeling_phi3 import _CONFIG_FOR_DOC
+from transformers.models.phi3.modeling_phi3 import PHI3_INPUTS_DOCSTRING
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
 
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 
 
 @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -64,19 +60,11 @@ def lce_forward_deprecated(
     'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
     ```"""
 
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -138,9 +126,7 @@ def lce_forward_deprecated(
 
 
 @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -202,19 +188,11 @@ def lce_forward(
             f"If you are not using the generate method, you may encounter nonsensical outputs after the {self.config.original_max_position_embeddings}th token, as the KV cache needs to be recomputed."
         )
 
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(

liger_kernel/transformers/model/qwen2.py

@@ -1,26 +1,22 @@
-from typing import List, Optional, Tuple, Union
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
+
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.qwen2.modeling_qwen2 import (
-    _CONFIG_FOR_DOC,
-    QWEN2_INPUTS_DOCSTRING,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
+from transformers.models.qwen2.modeling_qwen2 import _CONFIG_FOR_DOC
+from transformers.models.qwen2.modeling_qwen2 import QWEN2_INPUTS_DOCSTRING
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
 
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 
 
 @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -63,19 +59,11 @@ def lce_forward_deprecated(
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
     "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
     ```"""
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -137,9 +125,7 @@ def lce_forward_deprecated(
 
 
 @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -187,19 +173,11 @@ def lce_forward(
     "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
     ```"""
 
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(

liger_kernel/transformers/model/qwen2_vl.py

@@ -1,28 +1,24 @@
-from typing import List, Optional, Tuple, Union
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
+
 from packaging import version
 from torch.nn import CrossEntropyLoss
 from transformers import __version__ as transformers_version
-from transformers.models.qwen2_vl.modeling_qwen2_vl import (
-    _CONFIG_FOR_DOC,
-    QWEN2_VL_INPUTS_DOCSTRING,
-    Qwen2VLCausalLMOutputWithPast,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
+from transformers.models.qwen2_vl.modeling_qwen2_vl import _CONFIG_FOR_DOC
+from transformers.models.qwen2_vl.modeling_qwen2_vl import QWEN2_VL_INPUTS_DOCSTRING
+from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
 
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 
 
 @add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -82,19 +78,11 @@ def lce_forward(
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
     "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
     ```"""
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     if inputs_embeds is None:
         inputs_embeds = self.model.embed_tokens(input_ids)
@@ -144,9 +132,7 @@ def lce_forward(
         # transformers and leads to failed tests or users noticing differences in results.
         # TODO: remove above conditional when liger drops support for transformers<4.47.0
         if position_ids is None and input_ids is not None:
-            position_ids, _ = self.get_rope_index(
-                input_ids, image_grid_thw, video_grid_thw, attention_mask
-            )
+            position_ids, _ = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)
 
     outputs = self.model(
         input_ids=None,

liger_kernel/transformers/monkey_patch.py

@@ -1,9 +1,11 @@
 import inspect
 import logging
+
 from functools import partial
 from typing import Callable
 
 import transformers
+
 from packaging import version
 from transformers import PreTrainedModel
 
@@ -12,38 +14,24 @@ from liger_kernel.transformers.functional import liger_cross_entropy
 from liger_kernel.transformers.geglu import LigerGEGLUMLP
 from liger_kernel.transformers.layer_norm import LigerLayerNorm
 from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
-from liger_kernel.transformers.model.gemma import (
-    lce_forward_deprecated as gemma_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
 from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
-from liger_kernel.transformers.model.gemma2 import (
-    lce_forward_deprecated as gemma2_lce_forward_deprected,
-)
+from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
 from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
-from liger_kernel.transformers.model.llama import (
-    lce_forward_deprecated as llama_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
 from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
 from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
-from liger_kernel.transformers.model.mixtral import (
-    lce_forward_deprecated as mixtral_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated
 from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
-from liger_kernel.transformers.model.phi3 import (
-    lce_forward_deprecated as phi3_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.phi3 import lce_forward_deprecated as phi3_lce_forward_deprecated
 from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
-from liger_kernel.transformers.model.qwen2 import (
-    lce_forward_deprecated as qwen2_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.qwen2 import lce_forward_deprecated as qwen2_lce_forward_deprecated
 from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
-from liger_kernel.transformers.swiglu import (
-    LigerBlockSparseTop2MLP,
-    LigerPhi3SwiGLUMLP,
-    LigerSwiGLUMLP,
-)
+from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
+from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
+from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
 
 transformer_version = version.parse(transformers.__version__)
 
@@ -57,23 +45,17 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable):
     module.__dict__[method_name] = new_method.__get__(module, module.__class__)
 
 
-def _patch_rms_norm_module(
-    module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True
-):
+def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True):
     module.offset = offset
     module.casting_mode = casting_mode
-    module.variance_epsilon = (
-        getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
-    )
+    module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
     module.in_place = in_place
     _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
     _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
 
 
 def _patch_layer_norm_module(module, eps=1e-6):
-    module.variance_epsilon = (
-        getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
-    )
+    module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
     module.hidden_size = module.normalized_shape
     _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
     _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
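
The `_bind_method_to_module` and `_patch_rms_norm_module` helpers above replace methods on individual module instances by binding a plain function through Python's descriptor protocol (`__get__`) and writing it into the instance `__dict__`. A minimal standalone sketch of that pattern, using a hypothetical `fast_forward` replacement and an ordinary `nn.LayerNorm` target (illustrative only, not liger-kernel code):

```python
import torch
import torch.nn as nn


def fast_forward(self, x):
    # Hypothetical replacement forward; stands in for LigerRMSNorm.forward / LigerLayerNorm.forward.
    return x * self.weight


norm = nn.LayerNorm(8)

# __get__ turns the free function into a method bound to this one instance;
# assigning via __dict__ shadows the class-level forward for `norm` only,
# leaving every other LayerNorm instance untouched.
norm.__dict__["forward"] = fast_forward.__get__(norm, norm.__class__)

out = norm(torch.randn(2, 8))  # now dispatches to fast_forward
```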

@@ -145,9 +127,7 @@ def apply_liger_kernel_to_llama(
 
         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(
-                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-                )
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -184,17 +164,13 @@ def apply_liger_kernel_to_mllama(
     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
 
     from transformers.models.mllama import modeling_mllama
-    from transformers.models.mllama.modeling_mllama import (
-        MllamaForCausalLM,
-        MllamaForConditionalGeneration,
-        MllamaTextModel,
-        MllamaVisionModel,
-    )
+    from transformers.models.mllama.modeling_mllama import MllamaForCausalLM
+    from transformers.models.mllama.modeling_mllama import MllamaForConditionalGeneration
+    from transformers.models.mllama.modeling_mllama import MllamaTextModel
+    from transformers.models.mllama.modeling_mllama import MllamaVisionModel
 
     from liger_kernel.transformers.model.mllama import lce_forward as mllama_lce_forward
-    from liger_kernel.transformers.model.mllama import (
-        lce_forward_deprecated as mllama_lce_forward_deprecated,
-    )
+    from liger_kernel.transformers.model.mllama import lce_forward_deprecated as mllama_lce_forward_deprecated
 
     if rope:
         modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
@@ -241,9 +217,7 @@ def apply_liger_kernel_to_mllama(
             _patch_rms_norm_module(text_model.norm)
             for decoder_layer in text_model.layers:
                 if swiglu:
-                    _bind_method_to_module(
-                        decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-                    )
+                    _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
                 if rms_norm:
                     _patch_rms_norm_module(decoder_layer.input_layernorm)
                     _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -317,9 +291,7 @@ def apply_liger_kernel_to_mistral(
 
         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(
-                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-                )
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -391,9 +363,7 @@ def apply_liger_kernel_to_mixtral(
         for decoder_layer in base_model.layers:
             if swiglu:
                 for expert in decoder_layer.block_sparse_moe.experts:
-                    _bind_method_to_module(
-                        expert, "forward", LigerBlockSparseTop2MLP.forward
-                    )
+                    _bind_method_to_module(expert, "forward", LigerBlockSparseTop2MLP.forward)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -431,12 +401,8 @@ def apply_liger_kernel_to_gemma(
     from transformers.models.gemma.modeling_gemma import GemmaModel
 
     # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
-    LigerRMSNormForGemma = partial(
-        LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
-    )
-    _patch_rms_norm_module_for_gemma = partial(
-        _patch_rms_norm_module, casting_mode="gemma", offset=1.0
-    )
+    LigerRMSNormForGemma = partial(LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma")
+    _patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0)
 
     if rope:
         modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
@@ -471,9 +437,7 @@ def apply_liger_kernel_to_gemma(
 
         for decoder_layer in base_model.layers:
             if geglu:
-                _bind_method_to_module(
-                    decoder_layer.mlp, "forward", LigerGEGLUMLP.forward
-                )
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
             if rms_norm:
                 _patch_rms_norm_module_for_gemma(decoder_layer.input_layernorm)
                 _patch_rms_norm_module_for_gemma(decoder_layer.post_attention_layernorm)
@@ -510,9 +474,7 @@ def apply_liger_kernel_to_gemma2(
     from transformers.models.gemma2 import modeling_gemma2
     from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
 
-    LigerRMSNormForGemma2 = partial(
-        LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False
-    )
+    LigerRMSNormForGemma2 = partial(LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False)
     _patch_rms_norm_module_for_gemma2 = partial(
         _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
     )
@@ -551,20 +513,12 @@ def apply_liger_kernel_to_gemma2(
 
         for decoder_layer in base_model.layers:
             if geglu:
-                _bind_method_to_module(
-                    decoder_layer.mlp, "forward", LigerGEGLUMLP.forward
-                )
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
             if rms_norm:
                 _patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm)
-                _patch_rms_norm_module_for_gemma2(
-                    decoder_layer.post_attention_layernorm
-                )
-                _patch_rms_norm_module_for_gemma2(
-                    decoder_layer.pre_feedforward_layernorm
-                )
-                _patch_rms_norm_module_for_gemma2(
-                    decoder_layer.post_feedforward_layernorm
-                )
+                _patch_rms_norm_module_for_gemma2(decoder_layer.post_attention_layernorm)
+                _patch_rms_norm_module_for_gemma2(decoder_layer.pre_feedforward_layernorm)
+                _patch_rms_norm_module_for_gemma2(decoder_layer.post_feedforward_layernorm)
 
 
 def apply_liger_kernel_to_qwen2(
@@ -633,9 +587,7 @@ def apply_liger_kernel_to_qwen2(
 
         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(
-                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-                )
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -674,14 +626,10 @@ def apply_liger_kernel_to_qwen2_vl(
     from transformers.models.qwen2_vl import modeling_qwen2_vl
     from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel
 
-    from liger_kernel.transformers.model.qwen2_vl import (
-        lce_forward as qwen2_vl_lce_forward,
-    )
+    from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward
 
     if rope:
-        modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = (
-            liger_multimodal_rotary_pos_emb
-        )
+        modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
     if rms_norm:
         # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
         modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
@@ -712,9 +660,7 @@ def apply_liger_kernel_to_qwen2_vl(
         _patch_rms_norm_module(base_model.norm)
         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(
-                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-                )
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -783,9 +729,7 @@ def apply_liger_kernel_to_phi3(
 
         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(
-                    decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward
-                )
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -826,24 +770,16 @@ def _apply_liger_kernel(model_type: str, **kwargs) -> None:
         return
 
     if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys():
-        logger.info(
-            f"There are currently no Liger kernels supported for model type: {model_type}."
-        )
+        logger.info(f"There are currently no Liger kernels supported for model type: {model_type}.")
         return
 
     apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
     apply_fn_signature = inspect.signature(apply_fn)
 
     # Filter out the keyword arguments that are not supported by the apply function
-    applicable_kwargs = {
-        key: value
-        for key, value in kwargs.items()
-        if key in apply_fn_signature.parameters
-    }
+    applicable_kwargs = {key: value for key, value in kwargs.items() if key in apply_fn_signature.parameters}
 
-    logger.info(
-        f"Applying Liger kernels for model type: {model_type} with kwargs: {applicable_kwargs}"
-    )
+    logger.info(f"Applying Liger kernels for model type: {model_type} with kwargs: {applicable_kwargs}")
 
     # Assume this is invoked pre-model initialization, so we only need to patch transformers code
     apply_fn(**applicable_kwargs)
@@ -857,20 +793,14 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
         - model: the model instance to apply Liger kernels to
         - kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function.
     """
-    model_type = getattr(model, "config", None) and getattr(
-        model.config, "model_type", None
-    )
+    model_type = getattr(model, "config", None) and getattr(model.config, "model_type", None)
 
     if not model_type:
-        logger.info(
-            "Model type could not be determined from model config. No Liger kernels will be applied."
-        )
+        logger.info("Model type could not be determined from model config. No Liger kernels will be applied.")
         return
 
     if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys():
-        logger.info(
-            f"There are currently no Liger kernels supported for model type: {model_type}."
-        )
+        logger.info(f"There are currently no Liger kernels supported for model type: {model_type}.")
         return
 
     apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
@@ -878,11 +808,7 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
     apply_fn_signature = inspect.signature(apply_fn)
 
     # Filter out the keyword arguments that are not supported by the apply function
-    applicable_kwargs = {
-        key: value
-        for key, value in kwargs.items()
-        if key in apply_fn_signature.parameters
-    }
+    applicable_kwargs = {key: value for key, value in kwargs.items() if key in apply_fn_signature.parameters}
     logger.info(
         f"Applying Liger kernels to model instance with model type: {model_type} with kwargs: {applicable_kwargs}"
     )
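
The `_apply_liger_kernel*` helpers above forward only the keyword arguments that the selected `apply_liger_kernel_to_*` function actually declares, using `inspect.signature`. A small self-contained sketch of that filtering, with a made-up `apply_llama_like` stand-in for a real patch function:

```python
import inspect


def apply_llama_like(rope=True, rms_norm=True, swiglu=True):
    # Stand-in for an apply_liger_kernel_to_* patch function.
    return {"rope": rope, "rms_norm": rms_norm, "swiglu": swiglu}


kwargs = {"rope": False, "swiglu": True, "unrelated_flag": 123}

# Keep only the kwargs the target function declares, as in the diff above.
params = inspect.signature(apply_llama_like).parameters
applicable_kwargs = {key: value for key, value in kwargs.items() if key in params}

print(apply_llama_like(**applicable_kwargs))
# {'rope': False, 'rms_norm': True, 'swiglu': True}
```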

liger_kernel/transformers/rms_norm.py

@@ -19,9 +19,7 @@ class LigerRMSNorm(nn.Module):
             "ones",
             "zeros",
         ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
-        self.weight = nn.Parameter(
-            torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)
-        )
+        self.weight = nn.Parameter(torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size))
         self.variance_epsilon, self.offset, self.casting_mode, self.in_place = (
             eps,
             offset,
@@ -40,4 +38,6 @@ class LigerRMSNorm(nn.Module):
         )
 
     def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}"
+        return (
+            f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}"
+        )
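
For context on the `offset` and `init_fn` arguments that appear in `LigerRMSNorm` and in the Gemma patches in monkey_patch.py above: an offset of 1.0 with zero-initialized weights corresponds to a Gemma-style scale of `(1 + weight)`. A rough, unfused reference of that computation (an illustrative sketch under that assumption, not the Liger Triton kernel):

```python
import torch


def rms_norm_reference(x, weight, eps=1e-6, offset=0.0):
    # y = x / sqrt(mean(x^2) + eps) * (offset + weight)
    variance = x.pow(2).mean(-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * (offset + weight)


x = torch.randn(2, 4, 8)
weight = torch.zeros(8)                        # init_fn="zeros"
y = rms_norm_reference(x, weight, offset=1.0)  # effective scale is (1 + weight)
```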

liger_kernel/transformers/swiglu.py

@@ -16,10 +16,7 @@ class LigerSwiGLUMLP(nn.Module):
             raise ValueError(f"Activation function {config.hidden_act} not supported.")
 
     def forward(self, x):
-
-        return self.down_proj(
-            LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x))
-        )
+        return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
 
 
 class LigerBlockSparseTop2MLP(nn.Module):
@@ -36,7 +33,6 @@ class LigerBlockSparseTop2MLP(nn.Module):
             raise ValueError(f"Activation function {config.hidden_act} not supported.")
 
     def forward(self, x):
-
         return self.w2(LigerSiLUMulFunction.apply(self.w1(x), self.w3(x)))
 
 
@@ -51,9 +47,7 @@ class LigerPhi3SwiGLUMLP(nn.Module):
         self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
-        self.gate_up_proj = nn.Linear(
-            self.hidden_size, 2 * self.intermediate_size, bias=False
-        )
+        self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         if config.hidden_act not in ["silu", "swish"]:
             raise ValueError(f"Activation function {config.hidden_act} not supported.")
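
The `LigerSwiGLUMLP.forward` above feeds `gate_proj(x)` and `up_proj(x)` into a fused SiLU-and-multiply op; an unfused PyTorch equivalent of the same SwiGLU computation, `down_proj(silu(gate_proj(x)) * up_proj(x))`, looks roughly like this (layer sizes are arbitrary and chosen for illustration):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLUMLPReference(nn.Module):
    def __init__(self, hidden_size=16, intermediate_size=32):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):
        # The silu(gate) * up step is what LigerSiLUMulFunction.apply(gate, up) replaces.
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


y = SwiGLUMLPReference()(torch.randn(2, 16))
```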

liger_kernel/transformers/trainer/__init__.py

@@ -1,6 +1,4 @@
 try:
-    from liger_kernel.transformers.trainer.orpo_trainer import (
-        LigerORPOTrainer,
-    )
+    from liger_kernel.transformers.trainer.orpo_trainer import LigerORPOTrainer  # noqa: F401
 except ImportError:
     raise ImportError("Please `pip install trl` to use LigerORPOTrainer")