liger-kernel-nightly 0.5.2.dev20241223032630__py3-none-any.whl → 0.5.2.dev20241223042135__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
- liger_kernel/chunked_loss/cpo_loss.py +5 -11
- liger_kernel/chunked_loss/dpo_loss.py +1 -4
- liger_kernel/chunked_loss/fused_linear_distillation.py +37 -37
- liger_kernel/chunked_loss/fused_linear_preference.py +40 -64
- liger_kernel/chunked_loss/orpo_loss.py +2 -6
- liger_kernel/chunked_loss/simpo_loss.py +4 -8
- liger_kernel/env_report.py +4 -11
- liger_kernel/ops/cross_entropy.py +7 -10
- liger_kernel/ops/experimental/embedding.py +1 -3
- liger_kernel/ops/experimental/mm_int8int2.py +3 -9
- liger_kernel/ops/fused_linear_cross_entropy.py +7 -15
- liger_kernel/ops/fused_linear_jsd.py +11 -29
- liger_kernel/ops/geglu.py +6 -17
- liger_kernel/ops/group_norm.py +11 -28
- liger_kernel/ops/jsd.py +2 -6
- liger_kernel/ops/kl_div.py +4 -7
- liger_kernel/ops/layer_norm.py +3 -5
- liger_kernel/ops/qwen2vl_mrope.py +8 -25
- liger_kernel/ops/rms_norm.py +11 -29
- liger_kernel/ops/rope.py +8 -24
- liger_kernel/ops/swiglu.py +4 -8
- liger_kernel/ops/utils.py +2 -0
- liger_kernel/transformers/__init__.py +16 -24
- liger_kernel/transformers/auto_model.py +6 -13
- liger_kernel/transformers/cross_entropy.py +1 -3
- liger_kernel/transformers/experimental/embedding.py +1 -3
- liger_kernel/transformers/functional.py +2 -6
- liger_kernel/transformers/fused_linear_cross_entropy.py +2 -6
- liger_kernel/transformers/geglu.py +1 -4
- liger_kernel/transformers/group_norm.py +3 -9
- liger_kernel/transformers/jsd.py +1 -3
- liger_kernel/transformers/kl_div.py +1 -3
- liger_kernel/transformers/layer_norm.py +3 -9
- liger_kernel/transformers/model/gemma.py +18 -40
- liger_kernel/transformers/model/gemma2.py +19 -41
- liger_kernel/transformers/model/llama.py +22 -48
- liger_kernel/transformers/model/mistral.py +14 -26
- liger_kernel/transformers/model/mixtral.py +23 -53
- liger_kernel/transformers/model/mllama.py +16 -36
- liger_kernel/transformers/model/phi3.py +18 -40
- liger_kernel/transformers/model/qwen2.py +18 -40
- liger_kernel/transformers/model/qwen2_vl.py +16 -30
- liger_kernel/transformers/monkey_patch.py +43 -117
- liger_kernel/transformers/rms_norm.py +4 -4
- liger_kernel/transformers/swiglu.py +2 -8
- liger_kernel/transformers/trainer/__init__.py +1 -3
- liger_kernel/transformers/trainer/orpo_trainer.py +13 -16
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -3
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/METADATA +1 -1
- liger_kernel_nightly-0.5.2.dev20241223042135.dist-info/RECORD +66 -0
- liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/RECORD +0 -66
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/model/phi3.py

@@ -1,26 +1,22 @@
-from typing import List
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union

 import torch
+
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.phi3.modeling_phi3 import
-
-
-
-from transformers.utils import (
-add_start_docstrings_to_model_forward,
-replace_return_docstrings,
-)
+from transformers.models.phi3.modeling_phi3 import _CONFIG_FOR_DOC
+from transformers.models.phi3.modeling_phi3 import PHI3_INPUTS_DOCSTRING
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings

-from liger_kernel.transformers.fused_linear_cross_entropy import
-LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss


 @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
 self,
 input_ids: torch.LongTensor = None,

@@ -64,19 +60,11 @@ def lce_forward_deprecated(
 'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
 ```"""

-output_attentions =
-output_attentions
-if output_attentions is not None
-else self.config.output_attentions
-)
+output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 output_hidden_states = (
-output_hidden_states
-if output_hidden_states is not None
-else self.config.output_hidden_states
-)
-return_dict = (
-return_dict if return_dict is not None else self.config.use_return_dict
+output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 )
+return_dict = return_dict if return_dict is not None else self.config.use_return_dict

 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
 outputs = self.model(

@@ -138,9 +126,7 @@ def lce_forward_deprecated(


 @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
 self,
 input_ids: torch.LongTensor = None,

@@ -202,19 +188,11 @@ def lce_forward(
 f"If you are not using the generate method, you may encounter nonsensical outputs after the {self.config.original_max_position_embeddings}th token, as the KV cache needs to be recomputed."
 )

-output_attentions =
-output_attentions
-if output_attentions is not None
-else self.config.output_attentions
-)
+output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 output_hidden_states = (
-output_hidden_states
-if output_hidden_states is not None
-else self.config.output_hidden_states
-)
-return_dict = (
-return_dict if return_dict is not None else self.config.use_return_dict
+output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 )
+return_dict = return_dict if return_dict is not None else self.config.use_return_dict

 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
 outputs = self.model(
liger_kernel/transformers/model/qwen2.py

@@ -1,26 +1,22 @@
-from typing import List
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union

 import torch
+
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.qwen2.modeling_qwen2 import
-
-
-
-from transformers.utils import (
-add_start_docstrings_to_model_forward,
-replace_return_docstrings,
-)
+from transformers.models.qwen2.modeling_qwen2 import _CONFIG_FOR_DOC
+from transformers.models.qwen2.modeling_qwen2 import QWEN2_INPUTS_DOCSTRING
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings

-from liger_kernel.transformers.fused_linear_cross_entropy import
-LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss


 @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
 self,
 input_ids: torch.LongTensor = None,

@@ -63,19 +59,11 @@ def lce_forward_deprecated(
 >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
 "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
 ```"""
-output_attentions =
-output_attentions
-if output_attentions is not None
-else self.config.output_attentions
-)
+output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 output_hidden_states = (
-output_hidden_states
-if output_hidden_states is not None
-else self.config.output_hidden_states
-)
-return_dict = (
-return_dict if return_dict is not None else self.config.use_return_dict
+output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 )
+return_dict = return_dict if return_dict is not None else self.config.use_return_dict

 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
 outputs = self.model(

@@ -137,9 +125,7 @@ def lce_forward_deprecated(


 @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
 self,
 input_ids: torch.LongTensor = None,

@@ -187,19 +173,11 @@ def lce_forward(
 "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
 ```"""

-output_attentions =
-output_attentions
-if output_attentions is not None
-else self.config.output_attentions
-)
+output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 output_hidden_states = (
-output_hidden_states
-if output_hidden_states is not None
-else self.config.output_hidden_states
-)
-return_dict = (
-return_dict if return_dict is not None else self.config.use_return_dict
+output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 )
+return_dict = return_dict if return_dict is not None else self.config.use_return_dict

 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
 outputs = self.model(
liger_kernel/transformers/model/qwen2_vl.py

@@ -1,28 +1,24 @@
-from typing import List
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union

 import torch
+
 from packaging import version
 from torch.nn import CrossEntropyLoss
 from transformers import __version__ as transformers_version
-from transformers.models.qwen2_vl.modeling_qwen2_vl import
-
-
-
-
-from transformers.utils import (
-add_start_docstrings_to_model_forward,
-replace_return_docstrings,
-)
+from transformers.models.qwen2_vl.modeling_qwen2_vl import _CONFIG_FOR_DOC
+from transformers.models.qwen2_vl.modeling_qwen2_vl import QWEN2_VL_INPUTS_DOCSTRING
+from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings

-from liger_kernel.transformers.fused_linear_cross_entropy import
-LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss


 @add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
 self,
 input_ids: torch.LongTensor = None,

@@ -82,19 +78,11 @@ def lce_forward(
 >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
 "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
 ```"""
-output_attentions =
-output_attentions
-if output_attentions is not None
-else self.config.output_attentions
-)
+output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 output_hidden_states = (
-output_hidden_states
-if output_hidden_states is not None
-else self.config.output_hidden_states
-)
-return_dict = (
-return_dict if return_dict is not None else self.config.use_return_dict
+output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 )
+return_dict = return_dict if return_dict is not None else self.config.use_return_dict

 if inputs_embeds is None:
 inputs_embeds = self.model.embed_tokens(input_ids)

@@ -144,9 +132,7 @@ def lce_forward(
 # transformers and leads to failed tests or users noticing differences in results.
 # TODO: remove above conditional when liger drops support for transformers<4.47.0
 if position_ids is None and input_ids is not None:
-position_ids, _ = self.get_rope_index(
-input_ids, image_grid_thw, video_grid_thw, attention_mask
-)
+position_ids, _ = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)

 outputs = self.model(
 input_ids=None,
liger_kernel/transformers/monkey_patch.py

@@ -1,9 +1,11 @@
 import inspect
 import logging
+
 from functools import partial
 from typing import Callable

 import transformers
+
 from packaging import version
 from transformers import PreTrainedModel

@@ -12,38 +14,24 @@ from liger_kernel.transformers.functional import liger_cross_entropy
 from liger_kernel.transformers.geglu import LigerGEGLUMLP
 from liger_kernel.transformers.layer_norm import LigerLayerNorm
 from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
-from liger_kernel.transformers.model.gemma import
-lce_forward_deprecated as gemma_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
 from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
-from liger_kernel.transformers.model.gemma2 import
-lce_forward_deprecated as gemma2_lce_forward_deprected,
-)
+from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
 from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
-from liger_kernel.transformers.model.llama import
-lce_forward_deprecated as llama_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
 from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
 from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
-from liger_kernel.transformers.model.mixtral import
-lce_forward_deprecated as mixtral_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated
 from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
-from liger_kernel.transformers.model.phi3 import
-lce_forward_deprecated as phi3_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.phi3 import lce_forward_deprecated as phi3_lce_forward_deprecated
 from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
-from liger_kernel.transformers.model.qwen2 import
-lce_forward_deprecated as qwen2_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.qwen2 import lce_forward_deprecated as qwen2_lce_forward_deprecated
 from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
-from liger_kernel.transformers.swiglu import
-
-
-LigerSwiGLUMLP,
-)
+from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
+from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
+from liger_kernel.transformers.swiglu import LigerSwiGLUMLP

 transformer_version = version.parse(transformers.__version__)

@@ -57,23 +45,17 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable):
 module.__dict__[method_name] = new_method.__get__(module, module.__class__)


-def _patch_rms_norm_module(
-module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True
-):
+def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True):
 module.offset = offset
 module.casting_mode = casting_mode
-module.variance_epsilon = (
-getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
-)
+module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
 module.in_place = in_place
 _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
 _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)


 def _patch_layer_norm_module(module, eps=1e-6):
-module.variance_epsilon = (
-getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
-)
+module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
 module.hidden_size = module.normalized_shape
 _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
 _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)

@@ -145,9 +127,7 @@ def apply_liger_kernel_to_llama(

 for decoder_layer in base_model.layers:
 if swiglu:
-_bind_method_to_module(
-decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-)
+_bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
 if rms_norm:
 _patch_rms_norm_module(decoder_layer.input_layernorm)
 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)

@@ -184,17 +164,13 @@ def apply_liger_kernel_to_mllama(
 ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

 from transformers.models.mllama import modeling_mllama
-from transformers.models.mllama.modeling_mllama import
-
-
-
-MllamaVisionModel,
-)
+from transformers.models.mllama.modeling_mllama import MllamaForCausalLM
+from transformers.models.mllama.modeling_mllama import MllamaForConditionalGeneration
+from transformers.models.mllama.modeling_mllama import MllamaTextModel
+from transformers.models.mllama.modeling_mllama import MllamaVisionModel

 from liger_kernel.transformers.model.mllama import lce_forward as mllama_lce_forward
-from liger_kernel.transformers.model.mllama import
-lce_forward_deprecated as mllama_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.mllama import lce_forward_deprecated as mllama_lce_forward_deprecated

 if rope:
 modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb

@@ -241,9 +217,7 @@ def apply_liger_kernel_to_mllama(
 _patch_rms_norm_module(text_model.norm)
 for decoder_layer in text_model.layers:
 if swiglu:
-_bind_method_to_module(
-decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-)
+_bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
 if rms_norm:
 _patch_rms_norm_module(decoder_layer.input_layernorm)
 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)

@@ -317,9 +291,7 @@ def apply_liger_kernel_to_mistral(

 for decoder_layer in base_model.layers:
 if swiglu:
-_bind_method_to_module(
-decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-)
+_bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
 if rms_norm:
 _patch_rms_norm_module(decoder_layer.input_layernorm)
 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)

@@ -391,9 +363,7 @@ def apply_liger_kernel_to_mixtral(
 for decoder_layer in base_model.layers:
 if swiglu:
 for expert in decoder_layer.block_sparse_moe.experts:
-_bind_method_to_module(
-expert, "forward", LigerBlockSparseTop2MLP.forward
-)
+_bind_method_to_module(expert, "forward", LigerBlockSparseTop2MLP.forward)
 if rms_norm:
 _patch_rms_norm_module(decoder_layer.input_layernorm)
 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)

@@ -431,12 +401,8 @@ def apply_liger_kernel_to_gemma(
 from transformers.models.gemma.modeling_gemma import GemmaModel

 # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
-LigerRMSNormForGemma = partial(
-
-)
-_patch_rms_norm_module_for_gemma = partial(
-_patch_rms_norm_module, casting_mode="gemma", offset=1.0
-)
+LigerRMSNormForGemma = partial(LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma")
+_patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0)

 if rope:
 modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb

@@ -471,9 +437,7 @@ def apply_liger_kernel_to_gemma(

 for decoder_layer in base_model.layers:
 if geglu:
-_bind_method_to_module(
-decoder_layer.mlp, "forward", LigerGEGLUMLP.forward
-)
+_bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
 if rms_norm:
 _patch_rms_norm_module_for_gemma(decoder_layer.input_layernorm)
 _patch_rms_norm_module_for_gemma(decoder_layer.post_attention_layernorm)

@@ -510,9 +474,7 @@ def apply_liger_kernel_to_gemma2(
 from transformers.models.gemma2 import modeling_gemma2
 from transformers.models.gemma2.modeling_gemma2 import Gemma2Model

-LigerRMSNormForGemma2 = partial(
-LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False
-)
+LigerRMSNormForGemma2 = partial(LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False)
 _patch_rms_norm_module_for_gemma2 = partial(
 _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
 )

@@ -551,20 +513,12 @@ def apply_liger_kernel_to_gemma2(

 for decoder_layer in base_model.layers:
 if geglu:
-_bind_method_to_module(
-decoder_layer.mlp, "forward", LigerGEGLUMLP.forward
-)
+_bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
 if rms_norm:
 _patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm)
-_patch_rms_norm_module_for_gemma2(
-
-)
-_patch_rms_norm_module_for_gemma2(
-decoder_layer.pre_feedforward_layernorm
-)
-_patch_rms_norm_module_for_gemma2(
-decoder_layer.post_feedforward_layernorm
-)
+_patch_rms_norm_module_for_gemma2(decoder_layer.post_attention_layernorm)
+_patch_rms_norm_module_for_gemma2(decoder_layer.pre_feedforward_layernorm)
+_patch_rms_norm_module_for_gemma2(decoder_layer.post_feedforward_layernorm)


 def apply_liger_kernel_to_qwen2(

@@ -633,9 +587,7 @@ def apply_liger_kernel_to_qwen2(

 for decoder_layer in base_model.layers:
 if swiglu:
-_bind_method_to_module(
-decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-)
+_bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
 if rms_norm:
 _patch_rms_norm_module(decoder_layer.input_layernorm)
 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)

@@ -674,14 +626,10 @@ def apply_liger_kernel_to_qwen2_vl(
 from transformers.models.qwen2_vl import modeling_qwen2_vl
 from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel

-from liger_kernel.transformers.model.qwen2_vl import
-lce_forward as qwen2_vl_lce_forward,
-)
+from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward

 if rope:
-modeling_qwen2_vl.apply_multimodal_rotary_pos_emb =
-liger_multimodal_rotary_pos_emb
-)
+modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
 if rms_norm:
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
 modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm

@@ -712,9 +660,7 @@ def apply_liger_kernel_to_qwen2_vl(
 _patch_rms_norm_module(base_model.norm)
 for decoder_layer in base_model.layers:
 if swiglu:
-_bind_method_to_module(
-decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-)
+_bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
 if rms_norm:
 _patch_rms_norm_module(decoder_layer.input_layernorm)
 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)

@@ -783,9 +729,7 @@ def apply_liger_kernel_to_phi3(

 for decoder_layer in base_model.layers:
 if swiglu:
-_bind_method_to_module(
-decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward
-)
+_bind_method_to_module(decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward)
 if rms_norm:
 _patch_rms_norm_module(decoder_layer.input_layernorm)
 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)

@@ -826,24 +770,16 @@ def _apply_liger_kernel(model_type: str, **kwargs) -> None:
 return

 if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys():
-logger.info(
-f"There are currently no Liger kernels supported for model type: {model_type}."
-)
+logger.info(f"There are currently no Liger kernels supported for model type: {model_type}.")
 return

 apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
 apply_fn_signature = inspect.signature(apply_fn)

 # Filter out the keyword arguments that are not supported by the apply function
-applicable_kwargs = {
-key: value
-for key, value in kwargs.items()
-if key in apply_fn_signature.parameters
-}
+applicable_kwargs = {key: value for key, value in kwargs.items() if key in apply_fn_signature.parameters}

-logger.info(
-f"Applying Liger kernels for model type: {model_type} with kwargs: {applicable_kwargs}"
-)
+logger.info(f"Applying Liger kernels for model type: {model_type} with kwargs: {applicable_kwargs}")

 # Assume this is invoked pre-model initialization, so we only need to patch transformers code
 apply_fn(**applicable_kwargs)

@@ -857,20 +793,14 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
 - model: the model instance to apply Liger kernels to
 - kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function.
 """
-model_type = getattr(model, "config", None) and getattr(
-model.config, "model_type", None
-)
+model_type = getattr(model, "config", None) and getattr(model.config, "model_type", None)

 if not model_type:
-logger.info(
-"Model type could not be determined from model config. No Liger kernels will be applied."
-)
+logger.info("Model type could not be determined from model config. No Liger kernels will be applied.")
 return

 if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys():
-logger.info(
-f"There are currently no Liger kernels supported for model type: {model_type}."
-)
+logger.info(f"There are currently no Liger kernels supported for model type: {model_type}.")
 return

 apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]

@@ -878,11 +808,7 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
 apply_fn_signature = inspect.signature(apply_fn)

 # Filter out the keyword arguments that are not supported by the apply function
-applicable_kwargs = {
-key: value
-for key, value in kwargs.items()
-if key in apply_fn_signature.parameters
-}
+applicable_kwargs = {key: value for key, value in kwargs.items() if key in apply_fn_signature.parameters}
 logger.info(
 f"Applying Liger kernels to model instance with model type: {model_type} with kwargs: {applicable_kwargs}"
 )
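Most of the call sites in monkey_patch.py funnel through _bind_method_to_module, whose one-line body is visible in the hunk at old line 57 / new line 45: it uses Python's descriptor protocol to turn a plain function into a method bound to one specific module instance, so a single layer can be patched without touching its class. Below is a minimal standalone sketch of that binding trick (not liger-kernel code; bind_method_to_module and doubled_relu_forward are hypothetical stand-ins for illustration).

import torch
import torch.nn as nn


def bind_method_to_module(module, method_name, new_method):
    # __get__ binds the plain function to this one instance; writing it into
    # module.__dict__ shadows the class-level method for this instance only,
    # so other instances of the same class keep their original forward.
    module.__dict__[method_name] = new_method.__get__(module, module.__class__)


def doubled_relu_forward(self, x):
    # Hypothetical replacement forward; the real patches swap in fused Liger
    # forwards such as LigerSwiGLUMLP.forward or LigerRMSNorm.forward.
    return torch.relu(x) * 2


layer = nn.ReLU()
bind_method_to_module(layer, "forward", doubled_relu_forward)
print(layer(torch.tensor([-1.0, 2.0])))  # tensor([0., 4.]), via the patched forward

The design choice this illustrates: because the replacement is written into the instance's __dict__, the patch can be applied per decoder layer on an already constructed model, which is exactly what the apply_liger_kernel_to_* loops above do.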
liger_kernel/transformers/rms_norm.py

@@ -19,9 +19,7 @@ class LigerRMSNorm(nn.Module):
 "ones",
 "zeros",
 ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
-self.weight = nn.Parameter(
-torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)
-)
+self.weight = nn.Parameter(torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size))
 self.variance_epsilon, self.offset, self.casting_mode, self.in_place = (
 eps,
 offset,

@@ -40,4 +38,6 @@ class LigerRMSNorm(nn.Module):
 )

 def extra_repr(self):
-return
+return (
+f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}"
+)
liger_kernel/transformers/swiglu.py

@@ -16,10 +16,7 @@ class LigerSwiGLUMLP(nn.Module):
 raise ValueError(f"Activation function {config.hidden_act} not supported.")

 def forward(self, x):
-
-return self.down_proj(
-LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x))
-)
+return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))


 class LigerBlockSparseTop2MLP(nn.Module):

@@ -36,7 +33,6 @@ class LigerBlockSparseTop2MLP(nn.Module):
 raise ValueError(f"Activation function {config.hidden_act} not supported.")

 def forward(self, x):
-
 return self.w2(LigerSiLUMulFunction.apply(self.w1(x), self.w3(x)))


@@ -51,9 +47,7 @@ class LigerPhi3SwiGLUMLP(nn.Module):
 self.config = config
 self.hidden_size = config.hidden_size
 self.intermediate_size = config.intermediate_size
-self.gate_up_proj = nn.Linear(
-self.hidden_size, 2 * self.intermediate_size, bias=False
-)
+self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)
 self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
 if config.hidden_act not in ["silu", "swish"]:
 raise ValueError(f"Activation function {config.hidden_act} not supported.")
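For context on what these forwards compute: LigerSwiGLUMLP follows the standard SwiGLU MLP layout, with the silu(gate) * up step handled by the fused LigerSiLUMulFunction kernel. The following plain-PyTorch reference is a sketch for comparison only (not liger-kernel code; the class name and sizes below are hypothetical), written to mirror the gate_proj/up_proj/down_proj layout shown in the diff.

import torch
import torch.nn as nn
import torch.nn.functional as F


class ReferenceSwiGLUMLP(nn.Module):
    # Unfused reference for down_proj(silu(gate_proj(x)) * up_proj(x)).
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # LigerSiLUMulFunction.apply(gate, up) performs the silu-and-multiply in
        # one fused kernel; here the two steps are written out separately.
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


mlp = ReferenceSwiGLUMLP(hidden_size=16, intermediate_size=64)
print(mlp(torch.randn(2, 16)).shape)  # torch.Size([2, 16])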
liger_kernel/transformers/trainer/__init__.py

@@ -1,6 +1,4 @@
 try:
-from liger_kernel.transformers.trainer.orpo_trainer import
-LigerORPOTrainer,
-)
+from liger_kernel.transformers.trainer.orpo_trainer import LigerORPOTrainer  # noqa: F401
 except ImportError:
 raise ImportError("Please `pip install trl` to use LigerORPOTrainer")