liger-kernel 0.5.2__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. liger_kernel/chunked_loss/README.md +25 -0
  2. liger_kernel/chunked_loss/__init__.py +2 -0
  3. liger_kernel/chunked_loss/cpo_loss.py +18 -8
  4. liger_kernel/chunked_loss/dpo_loss.py +20 -10
  5. liger_kernel/chunked_loss/functional.py +4 -0
  6. liger_kernel/chunked_loss/fused_linear_distillation.py +58 -44
  7. liger_kernel/chunked_loss/fused_linear_preference.py +108 -60
  8. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +246 -0
  9. liger_kernel/chunked_loss/jsd_loss.py +154 -0
  10. liger_kernel/chunked_loss/kto_loss.py +172 -0
  11. liger_kernel/chunked_loss/orpo_loss.py +8 -9
  12. liger_kernel/chunked_loss/simpo_loss.py +22 -8
  13. liger_kernel/env_report.py +5 -12
  14. liger_kernel/ops/cross_entropy.py +102 -51
  15. liger_kernel/ops/experimental/embedding.py +1 -3
  16. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  17. liger_kernel/ops/fused_linear_cross_entropy.py +89 -55
  18. liger_kernel/ops/fused_linear_jsd.py +11 -29
  19. liger_kernel/ops/geglu.py +6 -17
  20. liger_kernel/ops/group_norm.py +11 -28
  21. liger_kernel/ops/jsd.py +2 -6
  22. liger_kernel/ops/kl_div.py +8 -11
  23. liger_kernel/ops/layer_norm.py +3 -5
  24. liger_kernel/ops/qwen2vl_mrope.py +8 -25
  25. liger_kernel/ops/rms_norm.py +14 -32
  26. liger_kernel/ops/rope.py +31 -33
  27. liger_kernel/ops/swiglu.py +4 -8
  28. liger_kernel/ops/utils.py +2 -0
  29. liger_kernel/transformers/__init__.py +16 -24
  30. liger_kernel/transformers/auto_model.py +6 -13
  31. liger_kernel/transformers/cross_entropy.py +4 -6
  32. liger_kernel/transformers/experimental/embedding.py +1 -3
  33. liger_kernel/transformers/functional.py +11 -7
  34. liger_kernel/transformers/fused_linear_cross_entropy.py +12 -7
  35. liger_kernel/transformers/geglu.py +1 -4
  36. liger_kernel/transformers/group_norm.py +3 -9
  37. liger_kernel/transformers/jsd.py +1 -3
  38. liger_kernel/transformers/kl_div.py +1 -3
  39. liger_kernel/transformers/layer_norm.py +3 -9
  40. liger_kernel/transformers/model/gemma.py +18 -40
  41. liger_kernel/transformers/model/gemma2.py +19 -41
  42. liger_kernel/transformers/model/llama.py +22 -48
  43. liger_kernel/transformers/model/mistral.py +14 -26
  44. liger_kernel/transformers/model/mixtral.py +24 -54
  45. liger_kernel/transformers/model/mllama.py +16 -36
  46. liger_kernel/transformers/model/phi3.py +18 -40
  47. liger_kernel/transformers/model/qwen2.py +18 -40
  48. liger_kernel/transformers/model/qwen2_vl.py +36 -32
  49. liger_kernel/transformers/monkey_patch.py +43 -117
  50. liger_kernel/transformers/rms_norm.py +4 -4
  51. liger_kernel/transformers/rope.py +2 -2
  52. liger_kernel/transformers/swiglu.py +2 -8
  53. liger_kernel/transformers/trainer/__init__.py +1 -3
  54. liger_kernel/transformers/trainer/orpo_trainer.py +31 -18
  55. liger_kernel/triton/__init__.py +1 -3
  56. liger_kernel/triton/monkey_patch.py +1 -3
  57. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.3.dist-info}/METADATA +38 -25
  58. liger_kernel-0.5.3.dist-info/RECORD +69 -0
  59. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.3.dist-info}/WHEEL +1 -1
  60. liger_kernel-0.5.2.dist-info/RECORD +0 -65
  61. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.3.dist-info}/LICENSE +0 -0
  62. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.3.dist-info}/NOTICE +0 -0
  63. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.3.dist-info}/top_level.txt +0 -0
@@ -1,27 +1,23 @@
1
- from typing import List, Optional, Tuple, Union
1
+ from typing import List
2
+ from typing import Optional
3
+ from typing import Tuple
4
+ from typing import Union
2
5
 
3
6
  import torch
7
+
4
8
  from torch.nn import CrossEntropyLoss
5
9
  from transformers.cache_utils import Cache
6
10
  from transformers.modeling_outputs import CausalLMOutputWithPast
7
- from transformers.models.mistral.modeling_mistral import (
8
- _CONFIG_FOR_DOC,
9
- MISTRAL_INPUTS_DOCSTRING,
10
- )
11
- from transformers.utils import (
12
- add_start_docstrings_to_model_forward,
13
- replace_return_docstrings,
14
- )
11
+ from transformers.models.mistral.modeling_mistral import _CONFIG_FOR_DOC
12
+ from transformers.models.mistral.modeling_mistral import MISTRAL_INPUTS_DOCSTRING
13
+ from transformers.utils import add_start_docstrings_to_model_forward
14
+ from transformers.utils import replace_return_docstrings
15
15
 
16
- from liger_kernel.transformers.fused_linear_cross_entropy import (
17
- LigerFusedLinearCrossEntropyLoss,
18
- )
16
+ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
19
17
 
20
18
 
21
19
  @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
22
- @replace_return_docstrings(
23
- output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
24
- )
20
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
25
21
  def lce_forward(
26
22
  self,
27
23
  input_ids: torch.LongTensor = None,
@@ -65,19 +61,11 @@ def lce_forward(
65
61
  "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
66
62
  ```"""
67
63
 
68
- output_attentions = (
69
- output_attentions
70
- if output_attentions is not None
71
- else self.config.output_attentions
72
- )
64
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
73
65
  output_hidden_states = (
74
- output_hidden_states
75
- if output_hidden_states is not None
76
- else self.config.output_hidden_states
77
- )
78
- return_dict = (
79
- return_dict if return_dict is not None else self.config.use_return_dict
66
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
80
67
  )
68
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
81
69
 
82
70
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
83
71
  outputs = self.model(
@@ -1,27 +1,23 @@
1
- from typing import List, Optional, Tuple, Union
1
+ from typing import List
2
+ from typing import Optional
3
+ from typing import Tuple
4
+ from typing import Union
2
5
 
3
6
  import torch
7
+
4
8
  from torch.nn import CrossEntropyLoss
5
9
  from transformers.modeling_outputs import MoeCausalLMOutputWithPast
6
- from transformers.models.mixtral.modeling_mixtral import (
7
- _CONFIG_FOR_DOC,
8
- MIXTRAL_INPUTS_DOCSTRING,
9
- load_balancing_loss_func,
10
- )
11
- from transformers.utils import (
12
- add_start_docstrings_to_model_forward,
13
- replace_return_docstrings,
14
- )
10
+ from transformers.models.mixtral.modeling_mixtral import _CONFIG_FOR_DOC
11
+ from transformers.models.mixtral.modeling_mixtral import MIXTRAL_INPUTS_DOCSTRING
12
+ from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
13
+ from transformers.utils import add_start_docstrings_to_model_forward
14
+ from transformers.utils import replace_return_docstrings
15
15
 
16
- from liger_kernel.transformers.fused_linear_cross_entropy import (
17
- LigerFusedLinearCrossEntropyLoss,
18
- )
16
+ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
19
17
 
20
18
 
21
19
  @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
22
- @replace_return_docstrings(
23
- output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
24
- )
20
+ @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
25
21
  def lce_forward_deprecated(
26
22
  self,
27
23
  input_ids: torch.LongTensor = None,
@@ -38,7 +34,7 @@ def lce_forward_deprecated(
38
34
  cache_position: Optional[torch.LongTensor] = None,
39
35
  ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
40
36
  r"""
41
- Copy paste Mixtral's forward from transfomers v4.44.2 but replace torch cross entropy with liger fused linear cross entropy
37
+ Copy paste Mixtral's forward from transformers v4.44.2 but replace torch cross entropy with liger fused linear cross entropy
42
38
 
43
39
 
44
40
  Args:
@@ -66,25 +62,15 @@ def lce_forward_deprecated(
66
62
  "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
67
63
  ```"""
68
64
 
69
- output_attentions = (
70
- output_attentions
71
- if output_attentions is not None
72
- else self.config.output_attentions
73
- )
65
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
74
66
  output_router_logits = (
75
- output_router_logits
76
- if output_router_logits is not None
77
- else self.config.output_router_logits
67
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
78
68
  )
79
69
 
80
70
  output_hidden_states = (
81
- output_hidden_states
82
- if output_hidden_states is not None
83
- else self.config.output_hidden_states
84
- )
85
- return_dict = (
86
- return_dict if return_dict is not None else self.config.use_return_dict
71
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
87
72
  )
73
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
88
74
 
89
75
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
90
76
  outputs = self.model(
@@ -138,9 +124,7 @@ def lce_forward_deprecated(
138
124
  attention_mask,
139
125
  )
140
126
  if labels is not None:
141
- loss += self.router_aux_loss_coef * aux_loss.to(
142
- loss.device
143
- ) # make sure to reside in the same device
127
+ loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
144
128
 
145
129
  if not return_dict:
146
130
  output = (logits,) + outputs[1:]
@@ -160,9 +144,7 @@ def lce_forward_deprecated(
160
144
 
161
145
 
162
146
  @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
163
- @replace_return_docstrings(
164
- output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
165
- )
147
+ @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
166
148
  # Ignore copy
167
149
  def lce_forward(
168
150
  self,
@@ -212,25 +194,15 @@ def lce_forward(
212
194
  "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
213
195
  ```"""
214
196
 
215
- output_attentions = (
216
- output_attentions
217
- if output_attentions is not None
218
- else self.config.output_attentions
219
- )
197
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
220
198
  output_router_logits = (
221
- output_router_logits
222
- if output_router_logits is not None
223
- else self.config.output_router_logits
199
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
224
200
  )
225
201
 
226
202
  output_hidden_states = (
227
- output_hidden_states
228
- if output_hidden_states is not None
229
- else self.config.output_hidden_states
230
- )
231
- return_dict = (
232
- return_dict if return_dict is not None else self.config.use_return_dict
203
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
233
204
  )
205
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
234
206
 
235
207
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
236
208
  outputs = self.model(
@@ -288,9 +260,7 @@ def lce_forward(
288
260
  attention_mask,
289
261
  )
290
262
  if labels is not None:
291
- loss += self.router_aux_loss_coef * aux_loss.to(
292
- loss.device
293
- ) # make sure to reside in the same device
263
+ loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
294
264
 
295
265
  if not return_dict:
296
266
  output = (logits,) + outputs[1:]
@@ -1,24 +1,22 @@
1
- from typing import List, Optional, Tuple, Union
1
+ from typing import List
2
+ from typing import Optional
3
+ from typing import Tuple
4
+ from typing import Union
2
5
 
3
6
  import torch
7
+
4
8
  from torch.nn import CrossEntropyLoss
5
9
  from transformers.cache_utils import Cache
6
10
  from transformers.modeling_outputs import CausalLMOutputWithPast
7
11
  from transformers.models.mllama.modeling_mllama import MLLAMA_INPUTS_DOCSTRING
8
- from transformers.utils import (
9
- add_start_docstrings_to_model_forward,
10
- replace_return_docstrings,
11
- )
12
+ from transformers.utils import add_start_docstrings_to_model_forward
13
+ from transformers.utils import replace_return_docstrings
12
14
 
13
- from liger_kernel.transformers.fused_linear_cross_entropy import (
14
- LigerFusedLinearCrossEntropyLoss,
15
- )
15
+ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
16
16
 
17
17
 
18
18
  @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
19
- @replace_return_docstrings(
20
- output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig"
21
- )
19
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig")
22
20
  def lce_forward_deprecated(
23
21
  self,
24
22
  input_ids: torch.LongTensor = None,
@@ -66,19 +64,11 @@ def lce_forward_deprecated(
66
64
  I love the idea of snowflakes gently falling, each one
67
65
  ```
68
66
  """
69
- output_attentions = (
70
- output_attentions
71
- if output_attentions is not None
72
- else self.config.output_attentions
73
- )
67
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
74
68
  output_hidden_states = (
75
- output_hidden_states
76
- if output_hidden_states is not None
77
- else self.config.output_hidden_states
78
- )
79
- return_dict = (
80
- return_dict if return_dict is not None else self.config.use_return_dict
69
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
81
70
  )
71
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
82
72
 
83
73
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
84
74
  outputs = self.model(
@@ -143,9 +133,7 @@ def lce_forward_deprecated(
143
133
 
144
134
 
145
135
  @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
146
- @replace_return_docstrings(
147
- output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig"
148
- )
136
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig")
149
137
  def lce_forward(
150
138
  self,
151
139
  input_ids: torch.LongTensor = None,
@@ -198,19 +186,11 @@ def lce_forward(
198
186
  I love the idea of snowflakes gently falling, each one
199
187
  ```
200
188
  """
201
- output_attentions = (
202
- output_attentions
203
- if output_attentions is not None
204
- else self.config.output_attentions
205
- )
189
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
206
190
  output_hidden_states = (
207
- output_hidden_states
208
- if output_hidden_states is not None
209
- else self.config.output_hidden_states
210
- )
211
- return_dict = (
212
- return_dict if return_dict is not None else self.config.use_return_dict
191
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
213
192
  )
193
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
214
194
 
215
195
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
216
196
  outputs = self.model(
@@ -1,26 +1,22 @@
1
- from typing import List, Optional, Tuple, Union
1
+ from typing import List
2
+ from typing import Optional
3
+ from typing import Tuple
4
+ from typing import Union
2
5
 
3
6
  import torch
7
+
4
8
  from torch.nn import CrossEntropyLoss
5
9
  from transformers.modeling_outputs import CausalLMOutputWithPast
6
- from transformers.models.phi3.modeling_phi3 import (
7
- _CONFIG_FOR_DOC,
8
- PHI3_INPUTS_DOCSTRING,
9
- )
10
- from transformers.utils import (
11
- add_start_docstrings_to_model_forward,
12
- replace_return_docstrings,
13
- )
10
+ from transformers.models.phi3.modeling_phi3 import _CONFIG_FOR_DOC
11
+ from transformers.models.phi3.modeling_phi3 import PHI3_INPUTS_DOCSTRING
12
+ from transformers.utils import add_start_docstrings_to_model_forward
13
+ from transformers.utils import replace_return_docstrings
14
14
 
15
- from liger_kernel.transformers.fused_linear_cross_entropy import (
16
- LigerFusedLinearCrossEntropyLoss,
17
- )
15
+ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
18
16
 
19
17
 
20
18
  @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
21
- @replace_return_docstrings(
22
- output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
23
- )
19
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
24
20
  def lce_forward_deprecated(
25
21
  self,
26
22
  input_ids: torch.LongTensor = None,
@@ -64,19 +60,11 @@ def lce_forward_deprecated(
64
60
  'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
65
61
  ```"""
66
62
 
67
- output_attentions = (
68
- output_attentions
69
- if output_attentions is not None
70
- else self.config.output_attentions
71
- )
63
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
72
64
  output_hidden_states = (
73
- output_hidden_states
74
- if output_hidden_states is not None
75
- else self.config.output_hidden_states
76
- )
77
- return_dict = (
78
- return_dict if return_dict is not None else self.config.use_return_dict
65
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
79
66
  )
67
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
80
68
 
81
69
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
82
70
  outputs = self.model(
@@ -138,9 +126,7 @@ def lce_forward_deprecated(
138
126
 
139
127
 
140
128
  @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
141
- @replace_return_docstrings(
142
- output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
143
- )
129
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
144
130
  def lce_forward(
145
131
  self,
146
132
  input_ids: torch.LongTensor = None,
@@ -202,19 +188,11 @@ def lce_forward(
202
188
  f"If you are not using the generate method, you may encounter nonsensical outputs after the {self.config.original_max_position_embeddings}th token, as the KV cache needs to be recomputed."
203
189
  )
204
190
 
205
- output_attentions = (
206
- output_attentions
207
- if output_attentions is not None
208
- else self.config.output_attentions
209
- )
191
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
210
192
  output_hidden_states = (
211
- output_hidden_states
212
- if output_hidden_states is not None
213
- else self.config.output_hidden_states
214
- )
215
- return_dict = (
216
- return_dict if return_dict is not None else self.config.use_return_dict
193
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
217
194
  )
195
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
218
196
 
219
197
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
220
198
  outputs = self.model(
@@ -1,26 +1,22 @@
1
- from typing import List, Optional, Tuple, Union
1
+ from typing import List
2
+ from typing import Optional
3
+ from typing import Tuple
4
+ from typing import Union
2
5
 
3
6
  import torch
7
+
4
8
  from torch.nn import CrossEntropyLoss
5
9
  from transformers.modeling_outputs import CausalLMOutputWithPast
6
- from transformers.models.qwen2.modeling_qwen2 import (
7
- _CONFIG_FOR_DOC,
8
- QWEN2_INPUTS_DOCSTRING,
9
- )
10
- from transformers.utils import (
11
- add_start_docstrings_to_model_forward,
12
- replace_return_docstrings,
13
- )
10
+ from transformers.models.qwen2.modeling_qwen2 import _CONFIG_FOR_DOC
11
+ from transformers.models.qwen2.modeling_qwen2 import QWEN2_INPUTS_DOCSTRING
12
+ from transformers.utils import add_start_docstrings_to_model_forward
13
+ from transformers.utils import replace_return_docstrings
14
14
 
15
- from liger_kernel.transformers.fused_linear_cross_entropy import (
16
- LigerFusedLinearCrossEntropyLoss,
17
- )
15
+ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
18
16
 
19
17
 
20
18
  @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
21
- @replace_return_docstrings(
22
- output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
23
- )
19
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
24
20
  def lce_forward_deprecated(
25
21
  self,
26
22
  input_ids: torch.LongTensor = None,
@@ -63,19 +59,11 @@ def lce_forward_deprecated(
63
59
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
64
60
  "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
65
61
  ```"""
66
- output_attentions = (
67
- output_attentions
68
- if output_attentions is not None
69
- else self.config.output_attentions
70
- )
62
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
71
63
  output_hidden_states = (
72
- output_hidden_states
73
- if output_hidden_states is not None
74
- else self.config.output_hidden_states
75
- )
76
- return_dict = (
77
- return_dict if return_dict is not None else self.config.use_return_dict
64
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
78
65
  )
66
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
79
67
 
80
68
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
81
69
  outputs = self.model(
@@ -137,9 +125,7 @@ def lce_forward_deprecated(
137
125
 
138
126
 
139
127
  @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
140
- @replace_return_docstrings(
141
- output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
142
- )
128
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
143
129
  def lce_forward(
144
130
  self,
145
131
  input_ids: torch.LongTensor = None,
@@ -187,19 +173,11 @@ def lce_forward(
187
173
  "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
188
174
  ```"""
189
175
 
190
- output_attentions = (
191
- output_attentions
192
- if output_attentions is not None
193
- else self.config.output_attentions
194
- )
176
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
195
177
  output_hidden_states = (
196
- output_hidden_states
197
- if output_hidden_states is not None
198
- else self.config.output_hidden_states
199
- )
200
- return_dict = (
201
- return_dict if return_dict is not None else self.config.use_return_dict
178
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
202
179
  )
180
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
203
181
 
204
182
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
205
183
  outputs = self.model(
@@ -1,28 +1,24 @@
1
- from typing import List, Optional, Tuple, Union
1
+ from typing import List
2
+ from typing import Optional
3
+ from typing import Tuple
4
+ from typing import Union
2
5
 
3
6
  import torch
7
+
4
8
  from packaging import version
5
9
  from torch.nn import CrossEntropyLoss
6
10
  from transformers import __version__ as transformers_version
7
- from transformers.models.qwen2_vl.modeling_qwen2_vl import (
8
- _CONFIG_FOR_DOC,
9
- QWEN2_VL_INPUTS_DOCSTRING,
10
- Qwen2VLCausalLMOutputWithPast,
11
- )
12
- from transformers.utils import (
13
- add_start_docstrings_to_model_forward,
14
- replace_return_docstrings,
15
- )
11
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import _CONFIG_FOR_DOC
12
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import QWEN2_VL_INPUTS_DOCSTRING
13
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast
14
+ from transformers.utils import add_start_docstrings_to_model_forward
15
+ from transformers.utils import replace_return_docstrings
16
16
 
17
- from liger_kernel.transformers.fused_linear_cross_entropy import (
18
- LigerFusedLinearCrossEntropyLoss,
19
- )
17
+ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
20
18
 
21
19
 
22
20
  @add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
23
- @replace_return_docstrings(
24
- output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
25
- )
21
+ @replace_return_docstrings(output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
26
22
  def lce_forward(
27
23
  self,
28
24
  input_ids: torch.LongTensor = None,
@@ -40,6 +36,7 @@ def lce_forward(
40
36
  image_grid_thw: Optional[torch.LongTensor] = None,
41
37
  video_grid_thw: Optional[torch.LongTensor] = None,
42
38
  rope_deltas: Optional[torch.LongTensor] = None,
39
+ cache_position: Optional[torch.LongTensor] = None,
43
40
  ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
44
41
  r"""
45
42
  Copy paste Qwen2VL's forward but replace torch cross entropy with liger fused linear cross entropy
@@ -82,19 +79,11 @@ def lce_forward(
82
79
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
83
80
  "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
84
81
  ```"""
85
- output_attentions = (
86
- output_attentions
87
- if output_attentions is not None
88
- else self.config.output_attentions
89
- )
82
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
90
83
  output_hidden_states = (
91
- output_hidden_states
92
- if output_hidden_states is not None
93
- else self.config.output_hidden_states
94
- )
95
- return_dict = (
96
- return_dict if return_dict is not None else self.config.use_return_dict
84
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
97
85
  )
86
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
98
87
 
99
88
  if inputs_embeds is None:
100
89
  inputs_embeds = self.model.embed_tokens(input_ids)
@@ -137,16 +126,30 @@ def lce_forward(
137
126
  if attention_mask is not None:
138
127
  attention_mask = attention_mask.to(inputs_embeds.device)
139
128
 
140
- if version.parse(transformers_version) > version.parse("4.46.2"):
129
+ if version.parse(transformers_version) > version.parse("4.46.3"):
141
130
  # NOTE: this bug fix for qwen2-vl is not applied until transformers 4.47.0
142
131
  # https://github.com/huggingface/transformers/issues/33401
143
132
  # While correct, this breaks equivalence with past versions of Qwen2-VL from
144
133
  # transformers and leads to failed tests or users noticing differences in results.
145
134
  # TODO: remove above conditional when liger drops support for transformers<4.47.0
146
- if position_ids is None and input_ids is not None:
147
- position_ids, _ = self.get_rope_index(
148
- input_ids, image_grid_thw, video_grid_thw, attention_mask
149
- )
135
+ # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
136
+ if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
137
+ # calculate RoPE index once per generation in the pre-fill stage only
138
+ if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
139
+ position_ids, rope_deltas = self.get_rope_index(
140
+ input_ids, image_grid_thw, video_grid_thw, attention_mask
141
+ )
142
+ self.rope_deltas = rope_deltas
143
+ # then use the prev pre-calculated rope-deltas to get the correct position ids
144
+ else:
145
+ batch_size, seq_length, _ = inputs_embeds.shape
146
+ delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
147
+ position_ids = torch.arange(seq_length, device=inputs_embeds.device)
148
+ position_ids = position_ids.view(1, -1).expand(batch_size, -1)
149
+ if cache_position is not None: # otherwise `deltas` is an int `0`
150
+ delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
151
+ position_ids = position_ids.add(delta)
152
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
150
153
 
151
154
  outputs = self.model(
152
155
  input_ids=None,
@@ -158,6 +161,7 @@ def lce_forward(
158
161
  output_attentions=output_attentions,
159
162
  output_hidden_states=output_hidden_states,
160
163
  return_dict=return_dict,
164
+ cache_position=cache_position,
161
165
  )
162
166
 
163
167
  hidden_states = outputs[0]