liger-kernel-nightly 0.5.2.dev20241223032630__py3-none-any.whl → 0.5.2.dev20241223042135__py3-none-any.whl

Files changed (56)
  1. liger_kernel/chunked_loss/cpo_loss.py +5 -11
  2. liger_kernel/chunked_loss/dpo_loss.py +1 -4
  3. liger_kernel/chunked_loss/fused_linear_distillation.py +37 -37
  4. liger_kernel/chunked_loss/fused_linear_preference.py +40 -64
  5. liger_kernel/chunked_loss/orpo_loss.py +2 -6
  6. liger_kernel/chunked_loss/simpo_loss.py +4 -8
  7. liger_kernel/env_report.py +4 -11
  8. liger_kernel/ops/cross_entropy.py +7 -10
  9. liger_kernel/ops/experimental/embedding.py +1 -3
  10. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  11. liger_kernel/ops/fused_linear_cross_entropy.py +7 -15
  12. liger_kernel/ops/fused_linear_jsd.py +11 -29
  13. liger_kernel/ops/geglu.py +6 -17
  14. liger_kernel/ops/group_norm.py +11 -28
  15. liger_kernel/ops/jsd.py +2 -6
  16. liger_kernel/ops/kl_div.py +4 -7
  17. liger_kernel/ops/layer_norm.py +3 -5
  18. liger_kernel/ops/qwen2vl_mrope.py +8 -25
  19. liger_kernel/ops/rms_norm.py +11 -29
  20. liger_kernel/ops/rope.py +8 -24
  21. liger_kernel/ops/swiglu.py +4 -8
  22. liger_kernel/ops/utils.py +2 -0
  23. liger_kernel/transformers/__init__.py +16 -24
  24. liger_kernel/transformers/auto_model.py +6 -13
  25. liger_kernel/transformers/cross_entropy.py +1 -3
  26. liger_kernel/transformers/experimental/embedding.py +1 -3
  27. liger_kernel/transformers/functional.py +2 -6
  28. liger_kernel/transformers/fused_linear_cross_entropy.py +2 -6
  29. liger_kernel/transformers/geglu.py +1 -4
  30. liger_kernel/transformers/group_norm.py +3 -9
  31. liger_kernel/transformers/jsd.py +1 -3
  32. liger_kernel/transformers/kl_div.py +1 -3
  33. liger_kernel/transformers/layer_norm.py +3 -9
  34. liger_kernel/transformers/model/gemma.py +18 -40
  35. liger_kernel/transformers/model/gemma2.py +19 -41
  36. liger_kernel/transformers/model/llama.py +22 -48
  37. liger_kernel/transformers/model/mistral.py +14 -26
  38. liger_kernel/transformers/model/mixtral.py +23 -53
  39. liger_kernel/transformers/model/mllama.py +16 -36
  40. liger_kernel/transformers/model/phi3.py +18 -40
  41. liger_kernel/transformers/model/qwen2.py +18 -40
  42. liger_kernel/transformers/model/qwen2_vl.py +16 -30
  43. liger_kernel/transformers/monkey_patch.py +43 -117
  44. liger_kernel/transformers/rms_norm.py +4 -4
  45. liger_kernel/transformers/swiglu.py +2 -8
  46. liger_kernel/transformers/trainer/__init__.py +1 -3
  47. liger_kernel/transformers/trainer/orpo_trainer.py +13 -16
  48. liger_kernel/triton/__init__.py +1 -3
  49. liger_kernel/triton/monkey_patch.py +1 -3
  50. {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/METADATA +1 -1
  51. liger_kernel_nightly-0.5.2.dev20241223042135.dist-info/RECORD +66 -0
  52. liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/RECORD +0 -66
  53. {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/LICENSE +0 -0
  54. {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/NOTICE +0 -0
  55. {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/WHEEL +0 -0
  56. {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/top_level.txt +0 -0

liger_kernel/transformers/model/phi3.py

@@ -1,26 +1,22 @@
-from typing import List, Optional, Tuple, Union
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
+
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.phi3.modeling_phi3 import (
-    _CONFIG_FOR_DOC,
-    PHI3_INPUTS_DOCSTRING,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
+from transformers.models.phi3.modeling_phi3 import _CONFIG_FOR_DOC
+from transformers.models.phi3.modeling_phi3 import PHI3_INPUTS_DOCSTRING
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
 
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 
 
 @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -64,19 +60,11 @@ def lce_forward_deprecated(
     'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
     ```"""
 
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -138,9 +126,7 @@ def lce_forward_deprecated(
 
 
 @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -202,19 +188,11 @@ def lce_forward(
             f"If you are not using the generate method, you may encounter nonsensical outputs after the {self.config.original_max_position_embeddings}th token, as the KV cache needs to be recomputed."
         )
 
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(

liger_kernel/transformers/model/qwen2.py

@@ -1,26 +1,22 @@
-from typing import List, Optional, Tuple, Union
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
+
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.qwen2.modeling_qwen2 import (
-    _CONFIG_FOR_DOC,
-    QWEN2_INPUTS_DOCSTRING,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
+from transformers.models.qwen2.modeling_qwen2 import _CONFIG_FOR_DOC
+from transformers.models.qwen2.modeling_qwen2 import QWEN2_INPUTS_DOCSTRING
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
 
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 
 
 @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -63,19 +59,11 @@ def lce_forward_deprecated(
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
     ```"""
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -137,9 +125,7 @@ def lce_forward_deprecated(
 
 
 @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -187,19 +173,11 @@ def lce_forward(
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
     ```"""
 
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(

liger_kernel/transformers/model/qwen2_vl.py

@@ -1,28 +1,24 @@
-from typing import List, Optional, Tuple, Union
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
+
 from packaging import version
 from torch.nn import CrossEntropyLoss
 from transformers import __version__ as transformers_version
-from transformers.models.qwen2_vl.modeling_qwen2_vl import (
-    _CONFIG_FOR_DOC,
-    QWEN2_VL_INPUTS_DOCSTRING,
-    Qwen2VLCausalLMOutputWithPast,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
+from transformers.models.qwen2_vl.modeling_qwen2_vl import _CONFIG_FOR_DOC
+from transformers.models.qwen2_vl.modeling_qwen2_vl import QWEN2_VL_INPUTS_DOCSTRING
+from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
 
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 
 
 @add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -82,19 +78,11 @@ def lce_forward(
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
     ```"""
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     if inputs_embeds is None:
         inputs_embeds = self.model.embed_tokens(input_ids)
@@ -144,9 +132,7 @@ def lce_forward(
         # transformers and leads to failed tests or users noticing differences in results.
         # TODO: remove above conditional when liger drops support for transformers<4.47.0
         if position_ids is None and input_ids is not None:
-            position_ids, _ = self.get_rope_index(
-                input_ids, image_grid_thw, video_grid_thw, attention_mask
-            )
+            position_ids, _ = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)
 
     outputs = self.model(
         input_ids=None,

liger_kernel/transformers/monkey_patch.py

@@ -1,9 +1,11 @@
 import inspect
 import logging
+
 from functools import partial
 from typing import Callable
 
 import transformers
+
 from packaging import version
 from transformers import PreTrainedModel
 
@@ -12,38 +14,24 @@ from liger_kernel.transformers.functional import liger_cross_entropy
 from liger_kernel.transformers.geglu import LigerGEGLUMLP
 from liger_kernel.transformers.layer_norm import LigerLayerNorm
 from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
-from liger_kernel.transformers.model.gemma import (
-    lce_forward_deprecated as gemma_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
 from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
-from liger_kernel.transformers.model.gemma2 import (
-    lce_forward_deprecated as gemma2_lce_forward_deprected,
-)
+from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
 from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
-from liger_kernel.transformers.model.llama import (
-    lce_forward_deprecated as llama_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
 from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
 from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
-from liger_kernel.transformers.model.mixtral import (
-    lce_forward_deprecated as mixtral_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated
 from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
-from liger_kernel.transformers.model.phi3 import (
-    lce_forward_deprecated as phi3_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.phi3 import lce_forward_deprecated as phi3_lce_forward_deprecated
 from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
-from liger_kernel.transformers.model.qwen2 import (
-    lce_forward_deprecated as qwen2_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.qwen2 import lce_forward_deprecated as qwen2_lce_forward_deprecated
 from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
-from liger_kernel.transformers.swiglu import (
-    LigerBlockSparseTop2MLP,
-    LigerPhi3SwiGLUMLP,
-    LigerSwiGLUMLP,
-)
+from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
+from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
+from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
 
 transformer_version = version.parse(transformers.__version__)
 
@@ -57,23 +45,17 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable):
     module.__dict__[method_name] = new_method.__get__(module, module.__class__)
 
 
-def _patch_rms_norm_module(
-    module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True
-):
+def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True):
     module.offset = offset
     module.casting_mode = casting_mode
-    module.variance_epsilon = (
-        getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
-    )
+    module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
     module.in_place = in_place
     _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
     _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
 
 
 def _patch_layer_norm_module(module, eps=1e-6):
-    module.variance_epsilon = (
-        getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
-    )
+    module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
     module.hidden_size = module.normalized_shape
     _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
     _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
@@ -145,9 +127,7 @@ def apply_liger_kernel_to_llama(
 
         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(
-                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-                )
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -184,17 +164,13 @@ def apply_liger_kernel_to_mllama(
     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
 
     from transformers.models.mllama import modeling_mllama
-    from transformers.models.mllama.modeling_mllama import (
-        MllamaForCausalLM,
-        MllamaForConditionalGeneration,
-        MllamaTextModel,
-        MllamaVisionModel,
-    )
+    from transformers.models.mllama.modeling_mllama import MllamaForCausalLM
+    from transformers.models.mllama.modeling_mllama import MllamaForConditionalGeneration
+    from transformers.models.mllama.modeling_mllama import MllamaTextModel
+    from transformers.models.mllama.modeling_mllama import MllamaVisionModel
 
     from liger_kernel.transformers.model.mllama import lce_forward as mllama_lce_forward
-    from liger_kernel.transformers.model.mllama import (
-        lce_forward_deprecated as mllama_lce_forward_deprecated,
-    )
+    from liger_kernel.transformers.model.mllama import lce_forward_deprecated as mllama_lce_forward_deprecated
 
     if rope:
         modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
@@ -241,9 +217,7 @@ def apply_liger_kernel_to_mllama(
                 _patch_rms_norm_module(text_model.norm)
             for decoder_layer in text_model.layers:
                 if swiglu:
-                    _bind_method_to_module(
-                        decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-                    )
+                    _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
                 if rms_norm:
                     _patch_rms_norm_module(decoder_layer.input_layernorm)
                     _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -317,9 +291,7 @@ def apply_liger_kernel_to_mistral(
 
         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(
-                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-                )
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -391,9 +363,7 @@ def apply_liger_kernel_to_mixtral(
         for decoder_layer in base_model.layers:
             if swiglu:
                 for expert in decoder_layer.block_sparse_moe.experts:
-                    _bind_method_to_module(
-                        expert, "forward", LigerBlockSparseTop2MLP.forward
-                    )
+                    _bind_method_to_module(expert, "forward", LigerBlockSparseTop2MLP.forward)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -431,12 +401,8 @@ def apply_liger_kernel_to_gemma(
     from transformers.models.gemma.modeling_gemma import GemmaModel
 
     # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
-    LigerRMSNormForGemma = partial(
-        LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
-    )
-    _patch_rms_norm_module_for_gemma = partial(
-        _patch_rms_norm_module, casting_mode="gemma", offset=1.0
-    )
+    LigerRMSNormForGemma = partial(LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma")
+    _patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0)
 
     if rope:
         modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
@@ -471,9 +437,7 @@ def apply_liger_kernel_to_gemma(
 
         for decoder_layer in base_model.layers:
             if geglu:
-                _bind_method_to_module(
-                    decoder_layer.mlp, "forward", LigerGEGLUMLP.forward
-                )
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
             if rms_norm:
                 _patch_rms_norm_module_for_gemma(decoder_layer.input_layernorm)
                 _patch_rms_norm_module_for_gemma(decoder_layer.post_attention_layernorm)
@@ -510,9 +474,7 @@ def apply_liger_kernel_to_gemma2(
     from transformers.models.gemma2 import modeling_gemma2
     from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
 
-    LigerRMSNormForGemma2 = partial(
-        LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False
-    )
+    LigerRMSNormForGemma2 = partial(LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False)
     _patch_rms_norm_module_for_gemma2 = partial(
         _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
     )
@@ -551,20 +513,12 @@ def apply_liger_kernel_to_gemma2(
 
         for decoder_layer in base_model.layers:
             if geglu:
-                _bind_method_to_module(
-                    decoder_layer.mlp, "forward", LigerGEGLUMLP.forward
-                )
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
             if rms_norm:
                 _patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm)
-                _patch_rms_norm_module_for_gemma2(
-                    decoder_layer.post_attention_layernorm
-                )
-                _patch_rms_norm_module_for_gemma2(
-                    decoder_layer.pre_feedforward_layernorm
-                )
-                _patch_rms_norm_module_for_gemma2(
-                    decoder_layer.post_feedforward_layernorm
-                )
+                _patch_rms_norm_module_for_gemma2(decoder_layer.post_attention_layernorm)
+                _patch_rms_norm_module_for_gemma2(decoder_layer.pre_feedforward_layernorm)
+                _patch_rms_norm_module_for_gemma2(decoder_layer.post_feedforward_layernorm)
 
 
 def apply_liger_kernel_to_qwen2(
@@ -633,9 +587,7 @@ def apply_liger_kernel_to_qwen2(
 
         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(
-                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-                )
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -674,14 +626,10 @@ def apply_liger_kernel_to_qwen2_vl(
     from transformers.models.qwen2_vl import modeling_qwen2_vl
     from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel
 
-    from liger_kernel.transformers.model.qwen2_vl import (
-        lce_forward as qwen2_vl_lce_forward,
-    )
+    from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward
 
     if rope:
-        modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = (
-            liger_multimodal_rotary_pos_emb
-        )
+        modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
     if rms_norm:
         # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
         modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
@@ -712,9 +660,7 @@ def apply_liger_kernel_to_qwen2_vl(
             _patch_rms_norm_module(base_model.norm)
         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(
-                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-                )
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -783,9 +729,7 @@ def apply_liger_kernel_to_phi3(
 
         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(
-                    decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward
-                )
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -826,24 +770,16 @@ def _apply_liger_kernel(model_type: str, **kwargs) -> None:
         return
 
     if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys():
-        logger.info(
-            f"There are currently no Liger kernels supported for model type: {model_type}."
-        )
+        logger.info(f"There are currently no Liger kernels supported for model type: {model_type}.")
        return
 
     apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
     apply_fn_signature = inspect.signature(apply_fn)
 
     # Filter out the keyword arguments that are not supported by the apply function
-    applicable_kwargs = {
-        key: value
-        for key, value in kwargs.items()
-        if key in apply_fn_signature.parameters
-    }
+    applicable_kwargs = {key: value for key, value in kwargs.items() if key in apply_fn_signature.parameters}
 
-    logger.info(
-        f"Applying Liger kernels for model type: {model_type} with kwargs: {applicable_kwargs}"
-    )
+    logger.info(f"Applying Liger kernels for model type: {model_type} with kwargs: {applicable_kwargs}")
 
     # Assume this is invoked pre-model initialization, so we only need to patch transformers code
     apply_fn(**applicable_kwargs)
@@ -857,20 +793,14 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
         - model: the model instance to apply Liger kernels to
         - kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function.
     """
-    model_type = getattr(model, "config", None) and getattr(
-        model.config, "model_type", None
-    )
+    model_type = getattr(model, "config", None) and getattr(model.config, "model_type", None)
 
     if not model_type:
-        logger.info(
-            "Model type could not be determined from model config. No Liger kernels will be applied."
-        )
+        logger.info("Model type could not be determined from model config. No Liger kernels will be applied.")
         return
 
     if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys():
-        logger.info(
-            f"There are currently no Liger kernels supported for model type: {model_type}."
-        )
+        logger.info(f"There are currently no Liger kernels supported for model type: {model_type}.")
         return
 
     apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
@@ -878,11 +808,7 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
     apply_fn_signature = inspect.signature(apply_fn)
 
     # Filter out the keyword arguments that are not supported by the apply function
-    applicable_kwargs = {
-        key: value
-        for key, value in kwargs.items()
-        if key in apply_fn_signature.parameters
-    }
+    applicable_kwargs = {key: value for key, value in kwargs.items() if key in apply_fn_signature.parameters}
     logger.info(
         f"Applying Liger kernels to model instance with model type: {model_type} with kwargs: {applicable_kwargs}"
    )
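
The monkey_patch.py hunks above lean on two plain-Python idioms: rebinding a function as a bound method on one module instance via the descriptor protocol (as in _bind_method_to_module) and narrowing **kwargs to whatever a target function's signature declares (as in _apply_liger_kernel). The following is a minimal standalone sketch of both patterns; ToyModule, patched_forward, and toy_apply_fn are illustrative names, not liger-kernel APIs.

import inspect

import torch.nn as nn


def bind_method_to_instance(module: nn.Module, method_name: str, new_method):
    # __get__ turns the plain function into a method bound to this one instance,
    # so only this object is patched, not its class.
    module.__dict__[method_name] = new_method.__get__(module, module.__class__)


def filter_kwargs_for(fn, **kwargs):
    # Keep only the keyword arguments the target function actually declares.
    params = inspect.signature(fn).parameters
    return {key: value for key, value in kwargs.items() if key in params}


class ToyModule(nn.Module):
    def forward(self, x):
        return x


def patched_forward(self, x):
    return x + 1


def toy_apply_fn(rope=True, rms_norm=True):
    return {"rope": rope, "rms_norm": rms_norm}


m = ToyModule()
bind_method_to_instance(m, "forward", patched_forward)
print(m(1))  # 2 -- the patched forward is used for this instance only

print(filter_kwargs_for(toy_apply_fn, rope=False, swiglu=True))  # {'rope': False}; unsupported 'swiglu' is dropped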

liger_kernel/transformers/rms_norm.py

@@ -19,9 +19,7 @@ class LigerRMSNorm(nn.Module):
             "ones",
             "zeros",
         ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
-        self.weight = nn.Parameter(
-            torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)
-        )
+        self.weight = nn.Parameter(torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size))
         self.variance_epsilon, self.offset, self.casting_mode, self.in_place = (
             eps,
             offset,
@@ -40,4 +38,6 @@ class LigerRMSNorm(nn.Module):
         )
 
     def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}"
+        return (
+            f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}"
+        )
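
The second rms_norm.py hunk only re-wraps the extra_repr return value for the longer line length; the string itself is unchanged. For readers unfamiliar with extra_repr: PyTorch appends its return value inside the parentheses when a module is printed. A small illustrative stand-in (not the Liger kernel itself) showing where that string surfaces:

import torch
import torch.nn as nn


class TinyNorm(nn.Module):
    # Illustrative only: mirrors LigerRMSNorm's extra_repr fields, no Triton kernel involved.
    def __init__(self, hidden_size, eps=1e-6, offset=0.0, in_place=True):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon, self.offset, self.in_place = eps, offset, in_place

    def forward(self, x):
        return x

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}"


print(TinyNorm(8))
# TinyNorm((8,), eps=1e-06, offset=0.0, in_place=True)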

liger_kernel/transformers/swiglu.py

@@ -16,10 +16,7 @@ class LigerSwiGLUMLP(nn.Module):
             raise ValueError(f"Activation function {config.hidden_act} not supported.")
 
     def forward(self, x):
-
-        return self.down_proj(
-            LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x))
-        )
+        return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
 
 
 class LigerBlockSparseTop2MLP(nn.Module):
@@ -36,7 +33,6 @@ class LigerBlockSparseTop2MLP(nn.Module):
             raise ValueError(f"Activation function {config.hidden_act} not supported.")
 
     def forward(self, x):
-
         return self.w2(LigerSiLUMulFunction.apply(self.w1(x), self.w3(x)))
 
 
@@ -51,9 +47,7 @@ class LigerPhi3SwiGLUMLP(nn.Module):
         self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
-        self.gate_up_proj = nn.Linear(
-            self.hidden_size, 2 * self.intermediate_size, bias=False
-        )
+        self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         if config.hidden_act not in ["silu", "swish"]:
             raise ValueError(f"Activation function {config.hidden_act} not supported.")

liger_kernel/transformers/trainer/__init__.py

@@ -1,6 +1,4 @@
 try:
-    from liger_kernel.transformers.trainer.orpo_trainer import (  # noqa: F401
-        LigerORPOTrainer,
-    )
+    from liger_kernel.transformers.trainer.orpo_trainer import LigerORPOTrainer  # noqa: F401
 except ImportError:
     raise ImportError("Please `pip install trl` to use LigerORPOTrainer")