liger-kernel-nightly 0.5.2.dev20241223032015__py3-none-any.whl → 0.5.2.dev20241223042135__py3-none-any.whl

Files changed (57)
  1. liger_kernel/chunked_loss/cpo_loss.py +5 -11
  2. liger_kernel/chunked_loss/dpo_loss.py +1 -4
  3. liger_kernel/chunked_loss/fused_linear_distillation.py +37 -37
  4. liger_kernel/chunked_loss/fused_linear_preference.py +40 -64
  5. liger_kernel/chunked_loss/orpo_loss.py +2 -6
  6. liger_kernel/chunked_loss/simpo_loss.py +4 -8
  7. liger_kernel/env_report.py +4 -11
  8. liger_kernel/ops/cross_entropy.py +7 -10
  9. liger_kernel/ops/experimental/embedding.py +1 -3
  10. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  11. liger_kernel/ops/fused_linear_cross_entropy.py +7 -15
  12. liger_kernel/ops/fused_linear_jsd.py +11 -29
  13. liger_kernel/ops/geglu.py +6 -17
  14. liger_kernel/ops/group_norm.py +11 -28
  15. liger_kernel/ops/jsd.py +2 -6
  16. liger_kernel/ops/kl_div.py +4 -7
  17. liger_kernel/ops/layer_norm.py +3 -5
  18. liger_kernel/ops/qwen2vl_mrope.py +8 -25
  19. liger_kernel/ops/rms_norm.py +11 -29
  20. liger_kernel/ops/rope.py +31 -33
  21. liger_kernel/ops/swiglu.py +4 -8
  22. liger_kernel/ops/utils.py +2 -0
  23. liger_kernel/transformers/__init__.py +16 -24
  24. liger_kernel/transformers/auto_model.py +6 -13
  25. liger_kernel/transformers/cross_entropy.py +1 -3
  26. liger_kernel/transformers/experimental/embedding.py +1 -3
  27. liger_kernel/transformers/functional.py +2 -6
  28. liger_kernel/transformers/fused_linear_cross_entropy.py +2 -6
  29. liger_kernel/transformers/geglu.py +1 -4
  30. liger_kernel/transformers/group_norm.py +3 -9
  31. liger_kernel/transformers/jsd.py +1 -3
  32. liger_kernel/transformers/kl_div.py +1 -3
  33. liger_kernel/transformers/layer_norm.py +3 -9
  34. liger_kernel/transformers/model/gemma.py +18 -40
  35. liger_kernel/transformers/model/gemma2.py +19 -41
  36. liger_kernel/transformers/model/llama.py +22 -48
  37. liger_kernel/transformers/model/mistral.py +14 -26
  38. liger_kernel/transformers/model/mixtral.py +23 -53
  39. liger_kernel/transformers/model/mllama.py +16 -36
  40. liger_kernel/transformers/model/phi3.py +18 -40
  41. liger_kernel/transformers/model/qwen2.py +18 -40
  42. liger_kernel/transformers/model/qwen2_vl.py +16 -30
  43. liger_kernel/transformers/monkey_patch.py +43 -117
  44. liger_kernel/transformers/rms_norm.py +4 -4
  45. liger_kernel/transformers/rope.py +2 -2
  46. liger_kernel/transformers/swiglu.py +2 -8
  47. liger_kernel/transformers/trainer/__init__.py +1 -3
  48. liger_kernel/transformers/trainer/orpo_trainer.py +13 -16
  49. liger_kernel/triton/__init__.py +1 -3
  50. liger_kernel/triton/monkey_patch.py +1 -3
  51. {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/METADATA +1 -1
  52. liger_kernel_nightly-0.5.2.dev20241223042135.dist-info/RECORD +66 -0
  53. liger_kernel_nightly-0.5.2.dev20241223032015.dist-info/RECORD +0 -66
  54. {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/LICENSE +0 -0
  55. {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/NOTICE +0 -0
  56. {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/WHEEL +0 -0
  57. {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/model/mllama.py
@@ -1,24 +1,22 @@
- from typing import List, Optional, Tuple, Union
+ from typing import List
+ from typing import Optional
+ from typing import Tuple
+ from typing import Union

  import torch
+
  from torch.nn import CrossEntropyLoss
  from transformers.cache_utils import Cache
  from transformers.modeling_outputs import CausalLMOutputWithPast
  from transformers.models.mllama.modeling_mllama import MLLAMA_INPUTS_DOCSTRING
- from transformers.utils import (
-     add_start_docstrings_to_model_forward,
-     replace_return_docstrings,
- )
+ from transformers.utils import add_start_docstrings_to_model_forward
+ from transformers.utils import replace_return_docstrings

- from liger_kernel.transformers.fused_linear_cross_entropy import (
-     LigerFusedLinearCrossEntropyLoss,
- )
+ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss


  @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
- @replace_return_docstrings(
-     output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig"
- )
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig")
  def lce_forward_deprecated(
      self,
      input_ids: torch.LongTensor = None,
@@ -66,19 +64,11 @@ def lce_forward_deprecated(
      I love the idea of snowflakes gently falling, each one
      ```
      """
-     output_attentions = (
-         output_attentions
-         if output_attentions is not None
-         else self.config.output_attentions
-     )
+     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
      output_hidden_states = (
-         output_hidden_states
-         if output_hidden_states is not None
-         else self.config.output_hidden_states
-     )
-     return_dict = (
-         return_dict if return_dict is not None else self.config.use_return_dict
+         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
      )
+     return_dict = return_dict if return_dict is not None else self.config.use_return_dict

      # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
      outputs = self.model(
@@ -143,9 +133,7 @@ def lce_forward_deprecated(


  @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
- @replace_return_docstrings(
-     output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig"
- )
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig")
  def lce_forward(
      self,
      input_ids: torch.LongTensor = None,
@@ -198,19 +186,11 @@ def lce_forward(
      I love the idea of snowflakes gently falling, each one
      ```
      """
-     output_attentions = (
-         output_attentions
-         if output_attentions is not None
-         else self.config.output_attentions
-     )
+     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
      output_hidden_states = (
-         output_hidden_states
-         if output_hidden_states is not None
-         else self.config.output_hidden_states
-     )
-     return_dict = (
-         return_dict if return_dict is not None else self.config.use_return_dict
+         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
      )
+     return_dict = return_dict if return_dict is not None else self.config.use_return_dict

      # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
      outputs = self.model(

liger_kernel/transformers/model/phi3.py
@@ -1,26 +1,22 @@
- from typing import List, Optional, Tuple, Union
+ from typing import List
+ from typing import Optional
+ from typing import Tuple
+ from typing import Union

  import torch
+
  from torch.nn import CrossEntropyLoss
  from transformers.modeling_outputs import CausalLMOutputWithPast
- from transformers.models.phi3.modeling_phi3 import (
-     _CONFIG_FOR_DOC,
-     PHI3_INPUTS_DOCSTRING,
- )
- from transformers.utils import (
-     add_start_docstrings_to_model_forward,
-     replace_return_docstrings,
- )
+ from transformers.models.phi3.modeling_phi3 import _CONFIG_FOR_DOC
+ from transformers.models.phi3.modeling_phi3 import PHI3_INPUTS_DOCSTRING
+ from transformers.utils import add_start_docstrings_to_model_forward
+ from transformers.utils import replace_return_docstrings

- from liger_kernel.transformers.fused_linear_cross_entropy import (
-     LigerFusedLinearCrossEntropyLoss,
- )
+ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss


  @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
- @replace_return_docstrings(
-     output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
- )
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward_deprecated(
      self,
      input_ids: torch.LongTensor = None,
@@ -64,19 +60,11 @@ def lce_forward_deprecated(
      'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
      ```"""

-     output_attentions = (
-         output_attentions
-         if output_attentions is not None
-         else self.config.output_attentions
-     )
+     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
      output_hidden_states = (
-         output_hidden_states
-         if output_hidden_states is not None
-         else self.config.output_hidden_states
-     )
-     return_dict = (
-         return_dict if return_dict is not None else self.config.use_return_dict
+         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
      )
+     return_dict = return_dict if return_dict is not None else self.config.use_return_dict

      # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
      outputs = self.model(
@@ -138,9 +126,7 @@ def lce_forward_deprecated(


  @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
- @replace_return_docstrings(
-     output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
- )
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
      self,
      input_ids: torch.LongTensor = None,
@@ -202,19 +188,11 @@ def lce_forward(
          f"If you are not using the generate method, you may encounter nonsensical outputs after the {self.config.original_max_position_embeddings}th token, as the KV cache needs to be recomputed."
      )

-     output_attentions = (
-         output_attentions
-         if output_attentions is not None
-         else self.config.output_attentions
-     )
+     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
      output_hidden_states = (
-         output_hidden_states
-         if output_hidden_states is not None
-         else self.config.output_hidden_states
-     )
-     return_dict = (
-         return_dict if return_dict is not None else self.config.use_return_dict
+         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
      )
+     return_dict = return_dict if return_dict is not None else self.config.use_return_dict

      # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
      outputs = self.model(

liger_kernel/transformers/model/qwen2.py
@@ -1,26 +1,22 @@
- from typing import List, Optional, Tuple, Union
+ from typing import List
+ from typing import Optional
+ from typing import Tuple
+ from typing import Union

  import torch
+
  from torch.nn import CrossEntropyLoss
  from transformers.modeling_outputs import CausalLMOutputWithPast
- from transformers.models.qwen2.modeling_qwen2 import (
-     _CONFIG_FOR_DOC,
-     QWEN2_INPUTS_DOCSTRING,
- )
- from transformers.utils import (
-     add_start_docstrings_to_model_forward,
-     replace_return_docstrings,
- )
+ from transformers.models.qwen2.modeling_qwen2 import _CONFIG_FOR_DOC
+ from transformers.models.qwen2.modeling_qwen2 import QWEN2_INPUTS_DOCSTRING
+ from transformers.utils import add_start_docstrings_to_model_forward
+ from transformers.utils import replace_return_docstrings

- from liger_kernel.transformers.fused_linear_cross_entropy import (
-     LigerFusedLinearCrossEntropyLoss,
- )
+ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss


  @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
- @replace_return_docstrings(
-     output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
- )
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward_deprecated(
      self,
      input_ids: torch.LongTensor = None,
@@ -63,19 +59,11 @@ def lce_forward_deprecated(
      >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
      "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
      ```"""
-     output_attentions = (
-         output_attentions
-         if output_attentions is not None
-         else self.config.output_attentions
-     )
+     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
      output_hidden_states = (
-         output_hidden_states
-         if output_hidden_states is not None
-         else self.config.output_hidden_states
-     )
-     return_dict = (
-         return_dict if return_dict is not None else self.config.use_return_dict
+         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
      )
+     return_dict = return_dict if return_dict is not None else self.config.use_return_dict

      # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
      outputs = self.model(
@@ -137,9 +125,7 @@ def lce_forward_deprecated(


  @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
- @replace_return_docstrings(
-     output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
- )
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
      self,
      input_ids: torch.LongTensor = None,
@@ -187,19 +173,11 @@ def lce_forward(
      "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
      ```"""

-     output_attentions = (
-         output_attentions
-         if output_attentions is not None
-         else self.config.output_attentions
-     )
+     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
      output_hidden_states = (
-         output_hidden_states
-         if output_hidden_states is not None
-         else self.config.output_hidden_states
-     )
-     return_dict = (
-         return_dict if return_dict is not None else self.config.use_return_dict
+         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
      )
+     return_dict = return_dict if return_dict is not None else self.config.use_return_dict

      # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
      outputs = self.model(

liger_kernel/transformers/model/qwen2_vl.py
@@ -1,28 +1,24 @@
- from typing import List, Optional, Tuple, Union
+ from typing import List
+ from typing import Optional
+ from typing import Tuple
+ from typing import Union

  import torch
+
  from packaging import version
  from torch.nn import CrossEntropyLoss
  from transformers import __version__ as transformers_version
- from transformers.models.qwen2_vl.modeling_qwen2_vl import (
-     _CONFIG_FOR_DOC,
-     QWEN2_VL_INPUTS_DOCSTRING,
-     Qwen2VLCausalLMOutputWithPast,
- )
- from transformers.utils import (
-     add_start_docstrings_to_model_forward,
-     replace_return_docstrings,
- )
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import _CONFIG_FOR_DOC
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import QWEN2_VL_INPUTS_DOCSTRING
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast
+ from transformers.utils import add_start_docstrings_to_model_forward
+ from transformers.utils import replace_return_docstrings

- from liger_kernel.transformers.fused_linear_cross_entropy import (
-     LigerFusedLinearCrossEntropyLoss,
- )
+ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss


  @add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
- @replace_return_docstrings(
-     output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
- )
+ @replace_return_docstrings(output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
      self,
      input_ids: torch.LongTensor = None,
@@ -82,19 +78,11 @@ def lce_forward(
      >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
      "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
      ```"""
-     output_attentions = (
-         output_attentions
-         if output_attentions is not None
-         else self.config.output_attentions
-     )
+     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
      output_hidden_states = (
-         output_hidden_states
-         if output_hidden_states is not None
-         else self.config.output_hidden_states
-     )
-     return_dict = (
-         return_dict if return_dict is not None else self.config.use_return_dict
+         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
      )
+     return_dict = return_dict if return_dict is not None else self.config.use_return_dict

      if inputs_embeds is None:
          inputs_embeds = self.model.embed_tokens(input_ids)
@@ -144,9 +132,7 @@ def lce_forward(
      # transformers and leads to failed tests or users noticing differences in results.
      # TODO: remove above conditional when liger drops support for transformers<4.47.0
      if position_ids is None and input_ids is not None:
-         position_ids, _ = self.get_rope_index(
-             input_ids, image_grid_thw, video_grid_thw, attention_mask
-         )
+         position_ids, _ = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)

      outputs = self.model(
          input_ids=None,