liger-kernel-nightly 0.5.2.dev20241223032630__py3-none-any.whl → 0.5.2.dev20241223042135__py3-none-any.whl

Files changed (56)
  1. liger_kernel/chunked_loss/cpo_loss.py +5 -11
  2. liger_kernel/chunked_loss/dpo_loss.py +1 -4
  3. liger_kernel/chunked_loss/fused_linear_distillation.py +37 -37
  4. liger_kernel/chunked_loss/fused_linear_preference.py +40 -64
  5. liger_kernel/chunked_loss/orpo_loss.py +2 -6
  6. liger_kernel/chunked_loss/simpo_loss.py +4 -8
  7. liger_kernel/env_report.py +4 -11
  8. liger_kernel/ops/cross_entropy.py +7 -10
  9. liger_kernel/ops/experimental/embedding.py +1 -3
  10. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  11. liger_kernel/ops/fused_linear_cross_entropy.py +7 -15
  12. liger_kernel/ops/fused_linear_jsd.py +11 -29
  13. liger_kernel/ops/geglu.py +6 -17
  14. liger_kernel/ops/group_norm.py +11 -28
  15. liger_kernel/ops/jsd.py +2 -6
  16. liger_kernel/ops/kl_div.py +4 -7
  17. liger_kernel/ops/layer_norm.py +3 -5
  18. liger_kernel/ops/qwen2vl_mrope.py +8 -25
  19. liger_kernel/ops/rms_norm.py +11 -29
  20. liger_kernel/ops/rope.py +8 -24
  21. liger_kernel/ops/swiglu.py +4 -8
  22. liger_kernel/ops/utils.py +2 -0
  23. liger_kernel/transformers/__init__.py +16 -24
  24. liger_kernel/transformers/auto_model.py +6 -13
  25. liger_kernel/transformers/cross_entropy.py +1 -3
  26. liger_kernel/transformers/experimental/embedding.py +1 -3
  27. liger_kernel/transformers/functional.py +2 -6
  28. liger_kernel/transformers/fused_linear_cross_entropy.py +2 -6
  29. liger_kernel/transformers/geglu.py +1 -4
  30. liger_kernel/transformers/group_norm.py +3 -9
  31. liger_kernel/transformers/jsd.py +1 -3
  32. liger_kernel/transformers/kl_div.py +1 -3
  33. liger_kernel/transformers/layer_norm.py +3 -9
  34. liger_kernel/transformers/model/gemma.py +18 -40
  35. liger_kernel/transformers/model/gemma2.py +19 -41
  36. liger_kernel/transformers/model/llama.py +22 -48
  37. liger_kernel/transformers/model/mistral.py +14 -26
  38. liger_kernel/transformers/model/mixtral.py +23 -53
  39. liger_kernel/transformers/model/mllama.py +16 -36
  40. liger_kernel/transformers/model/phi3.py +18 -40
  41. liger_kernel/transformers/model/qwen2.py +18 -40
  42. liger_kernel/transformers/model/qwen2_vl.py +16 -30
  43. liger_kernel/transformers/monkey_patch.py +43 -117
  44. liger_kernel/transformers/rms_norm.py +4 -4
  45. liger_kernel/transformers/swiglu.py +2 -8
  46. liger_kernel/transformers/trainer/__init__.py +1 -3
  47. liger_kernel/transformers/trainer/orpo_trainer.py +13 -16
  48. liger_kernel/triton/__init__.py +1 -3
  49. liger_kernel/triton/monkey_patch.py +1 -3
  50. {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/METADATA +1 -1
  51. liger_kernel_nightly-0.5.2.dev20241223042135.dist-info/RECORD +66 -0
  52. liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/RECORD +0 -66
  53. {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/LICENSE +0 -0
  54. {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/NOTICE +0 -0
  55. {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/WHEEL +0 -0
  56. {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/layer_norm.py
@@ -13,18 +13,12 @@ class LigerLayerNorm(nn.Module):
         ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
         self.hidden_size = hidden_size
         self.eps = eps
-        self.weight = nn.Parameter(
-            torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)
-        )
-        self.bias = nn.Parameter(
-            torch.randn(hidden_size) if bias else torch.zeros(hidden_size)
-        )
+        self.weight = nn.Parameter(torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size))
+        self.bias = nn.Parameter(torch.randn(hidden_size) if bias else torch.zeros(hidden_size))
         self.variance_epsilon = eps
 
     def forward(self, hidden_states):
-        return LigerLayerNormFunction.apply(
-            hidden_states, self.weight, self.bias, self.variance_epsilon
-        )
+        return LigerLayerNormFunction.apply(hidden_states, self.weight, self.bias, self.variance_epsilon)
 
     def extra_repr(self):
         return f"{self.hidden_size}, eps={self.eps}"
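For reference, a minimal usage sketch of the LigerLayerNorm module touched above. It assumes the constructor arguments visible in this hunk (hidden_size, eps, init_fn, bias) and uses made-up tensor sizes; the file itself remains the authoritative source for the signature and defaults.

import torch

from liger_kernel.transformers.layer_norm import LigerLayerNorm

# Hypothetical sizes; the Triton kernel behind LigerLayerNormFunction expects CUDA tensors.
ln = LigerLayerNorm(hidden_size=4096, eps=1e-6).to("cuda")
x = torch.randn(2, 128, 4096, device="cuda")
y = ln(x)
print(y.shape)  # torch.Size([2, 128, 4096])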
liger_kernel/transformers/model/gemma.py
@@ -1,27 +1,23 @@
-from typing import List, Optional, Tuple, Union
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
+
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.gemma.modeling_gemma import (
-    _CONFIG_FOR_DOC,
-    GEMMA_INPUTS_DOCSTRING,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
+from transformers.models.gemma.modeling_gemma import _CONFIG_FOR_DOC
+from transformers.models.gemma.modeling_gemma import GEMMA_INPUTS_DOCSTRING
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
 
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 
 
 @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -64,19 +60,11 @@ def lce_forward_deprecated(
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
     "What is your favorite condiment?"
     ```"""
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -139,9 +127,7 @@ def lce_forward_deprecated(
 
 
 @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -188,19 +174,11 @@ def lce_forward(
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
     "What is your favorite condiment?"
     ```"""
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
liger_kernel/transformers/model/gemma2.py
@@ -1,22 +1,20 @@
 import logging
-from typing import Optional, Tuple, Union
+
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
+
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import HybridCache
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.gemma2.modeling_gemma2 import (
-    _CONFIG_FOR_DOC,
-    GEMMA2_INPUTS_DOCSTRING,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
-
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from transformers.models.gemma2.modeling_gemma2 import _CONFIG_FOR_DOC
+from transformers.models.gemma2.modeling_gemma2 import GEMMA2_INPUTS_DOCSTRING
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
+
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 
 logger = logging.getLogger(__name__)
 
@@ -63,19 +61,11 @@ def lce_forward_deprecated(
             "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
             f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
         )
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
         input_ids=input_ids,
@@ -104,9 +94,7 @@ def lce_forward_deprecated(
         shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
         shift_labels = shift_labels.view(-1)
 
-        lce = LigerFusedLinearCrossEntropyLoss(
-            softcap=self.config.final_logit_softcapping
-        )
+        lce = LigerFusedLinearCrossEntropyLoss(softcap=self.config.final_logit_softcapping)
         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
 
     else:
@@ -146,9 +134,7 @@ def lce_forward_deprecated(
 
 
 @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -201,19 +187,11 @@ def lce_forward(
             "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
             f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
         )
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
         input_ids=input_ids,
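The gemma2 hunks above build LigerFusedLinearCrossEntropyLoss(softcap=self.config.final_logit_softcapping) and call it as lce(self.lm_head.weight, shift_hidden_states, shift_labels), so the lm_head projection and the cross-entropy are fused and full logits are never materialized. A standalone sketch of that call pattern, with made-up sizes and a hypothetical softcap value:

import torch

from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss

hidden_size, vocab_size, num_tokens = 4096, 256_000, 512  # illustrative only
lm_head_weight = torch.randn(vocab_size, hidden_size, device="cuda", dtype=torch.bfloat16)
hidden_states = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.bfloat16, requires_grad=True)
labels = torch.randint(0, vocab_size, (num_tokens,), device="cuda")

# softcap mirrors Gemma2's final_logit_softcapping; other models in this diff
# construct the loss with no arguments.
lce = LigerFusedLinearCrossEntropyLoss(softcap=30.0)
loss = lce(lm_head_weight, hidden_states, labels)  # (weight, input, target), as in the hunk
loss.backward()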
liger_kernel/transformers/model/llama.py
@@ -1,30 +1,27 @@
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
 import torch.nn.functional as F
+
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.llama.modeling_llama import (
-    _CONFIG_FOR_DOC,
-    LLAMA_INPUTS_DOCSTRING,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
-
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from transformers.models.llama.modeling_llama import _CONFIG_FOR_DOC
+from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
+
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 
 if TYPE_CHECKING:
     from transformers.cache_utils import Cache
 
 
 @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -67,19 +64,11 @@ def lce_forward_deprecated(
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
     "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
     ```"""
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -113,13 +102,8 @@ def lce_forward_deprecated(
 
     else:
         if self.config.pretraining_tp > 1:
-            lm_head_slices = self.lm_head.weight.split(
-                self.vocab_size // self.config.pretraining_tp, dim=0
-            )
-            logits = [
-                F.linear(hidden_states, lm_head_slices[i])
-                for i in range(self.config.pretraining_tp)
-            ]
+            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
+            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
             logits = torch.cat(logits, dim=-1)
         else:
             logits = self.lm_head(hidden_states)
@@ -151,9 +135,7 @@ def lce_forward_deprecated(
 
 
 @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -201,19 +183,11 @@ def lce_forward(
     "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
     ```"""
 
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
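The llama hunk above collapses the pretraining_tp branch: lm_head.weight is split into row blocks, F.linear is applied per block, and the results are concatenated along the vocab dimension. A small self-contained sketch (arbitrary sizes) showing that this blocked path matches a single full projection:

import torch
import torch.nn.functional as F

tp, vocab_size, hidden_size = 4, 8_000, 512  # made-up sizes for illustration
weight = torch.randn(vocab_size, hidden_size)
hidden_states = torch.randn(8, hidden_size)

# Split the projection into `tp` row blocks and stitch the per-block logits back together.
slices = weight.split(vocab_size // tp, dim=0)
logits = torch.cat([F.linear(hidden_states, s) for s in slices], dim=-1)

assert torch.allclose(logits, F.linear(hidden_states, weight), atol=1e-4)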
liger_kernel/transformers/model/mistral.py
@@ -1,27 +1,23 @@
-from typing import List, Optional, Tuple, Union
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
+
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.mistral.modeling_mistral import (
-    _CONFIG_FOR_DOC,
-    MISTRAL_INPUTS_DOCSTRING,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
+from transformers.models.mistral.modeling_mistral import _CONFIG_FOR_DOC
+from transformers.models.mistral.modeling_mistral import MISTRAL_INPUTS_DOCSTRING
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
 
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 
 
 @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -65,19 +61,11 @@ def lce_forward(
     "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
     ```"""
 
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
liger_kernel/transformers/model/mixtral.py
@@ -1,27 +1,23 @@
-from typing import List, Optional, Tuple, Union
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
+
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import MoeCausalLMOutputWithPast
-from transformers.models.mixtral.modeling_mixtral import (
-    _CONFIG_FOR_DOC,
-    MIXTRAL_INPUTS_DOCSTRING,
-    load_balancing_loss_func,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
+from transformers.models.mixtral.modeling_mixtral import _CONFIG_FOR_DOC
+from transformers.models.mixtral.modeling_mixtral import MIXTRAL_INPUTS_DOCSTRING
+from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
 
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 
 
 @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -66,25 +62,15 @@ def lce_forward_deprecated(
     "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
     ```"""
 
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_router_logits = (
-        output_router_logits
-        if output_router_logits is not None
-        else self.config.output_router_logits
+        output_router_logits if output_router_logits is not None else self.config.output_router_logits
     )
 
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -138,9 +124,7 @@ def lce_forward_deprecated(
             attention_mask,
         )
         if labels is not None:
-            loss += self.router_aux_loss_coef * aux_loss.to(
-                loss.device
-            )  # make sure to reside in the same device
+            loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
 
     if not return_dict:
         output = (logits,) + outputs[1:]
@@ -160,9 +144,7 @@ def lce_forward_deprecated(
 
 
 @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 # Ignore copy
 def lce_forward(
     self,
@@ -212,25 +194,15 @@ def lce_forward(
     "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
     ```"""
 
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_router_logits = (
-        output_router_logits
-        if output_router_logits is not None
-        else self.config.output_router_logits
+        output_router_logits if output_router_logits is not None else self.config.output_router_logits
     )
 
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -288,9 +260,7 @@ def lce_forward(
             attention_mask,
         )
         if labels is not None:
-            loss += self.router_aux_loss_coef * aux_loss.to(
-                loss.device
-            )  # make sure to reside in the same device
+            loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
 
     if not return_dict:
         output = (logits,) + outputs[1:]
liger_kernel/transformers/model/mllama.py
@@ -1,24 +1,22 @@
-from typing import List, Optional, Tuple, Union
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
+
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.models.mllama.modeling_mllama import MLLAMA_INPUTS_DOCSTRING
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
 
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 
 
 @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig"
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig")
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -66,19 +64,11 @@ def lce_forward_deprecated(
     I love the idea of snowflakes gently falling, each one
     ```
     """
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -143,9 +133,7 @@ def lce_forward_deprecated(
 
 
 @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig"
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig")
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -198,19 +186,11 @@ def lce_forward(
     I love the idea of snowflakes gently falling, each one
     ```
     """
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(