liger-kernel-nightly 0.5.2.dev20241223032015__py3-none-any.whl → 0.5.2.dev20241223042135__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (57)
  1. liger_kernel/chunked_loss/cpo_loss.py +5 -11
  2. liger_kernel/chunked_loss/dpo_loss.py +1 -4
  3. liger_kernel/chunked_loss/fused_linear_distillation.py +37 -37
  4. liger_kernel/chunked_loss/fused_linear_preference.py +40 -64
  5. liger_kernel/chunked_loss/orpo_loss.py +2 -6
  6. liger_kernel/chunked_loss/simpo_loss.py +4 -8
  7. liger_kernel/env_report.py +4 -11
  8. liger_kernel/ops/cross_entropy.py +7 -10
  9. liger_kernel/ops/experimental/embedding.py +1 -3
  10. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  11. liger_kernel/ops/fused_linear_cross_entropy.py +7 -15
  12. liger_kernel/ops/fused_linear_jsd.py +11 -29
  13. liger_kernel/ops/geglu.py +6 -17
  14. liger_kernel/ops/group_norm.py +11 -28
  15. liger_kernel/ops/jsd.py +2 -6
  16. liger_kernel/ops/kl_div.py +4 -7
  17. liger_kernel/ops/layer_norm.py +3 -5
  18. liger_kernel/ops/qwen2vl_mrope.py +8 -25
  19. liger_kernel/ops/rms_norm.py +11 -29
  20. liger_kernel/ops/rope.py +31 -33
  21. liger_kernel/ops/swiglu.py +4 -8
  22. liger_kernel/ops/utils.py +2 -0
  23. liger_kernel/transformers/__init__.py +16 -24
  24. liger_kernel/transformers/auto_model.py +6 -13
  25. liger_kernel/transformers/cross_entropy.py +1 -3
  26. liger_kernel/transformers/experimental/embedding.py +1 -3
  27. liger_kernel/transformers/functional.py +2 -6
  28. liger_kernel/transformers/fused_linear_cross_entropy.py +2 -6
  29. liger_kernel/transformers/geglu.py +1 -4
  30. liger_kernel/transformers/group_norm.py +3 -9
  31. liger_kernel/transformers/jsd.py +1 -3
  32. liger_kernel/transformers/kl_div.py +1 -3
  33. liger_kernel/transformers/layer_norm.py +3 -9
  34. liger_kernel/transformers/model/gemma.py +18 -40
  35. liger_kernel/transformers/model/gemma2.py +19 -41
  36. liger_kernel/transformers/model/llama.py +22 -48
  37. liger_kernel/transformers/model/mistral.py +14 -26
  38. liger_kernel/transformers/model/mixtral.py +23 -53
  39. liger_kernel/transformers/model/mllama.py +16 -36
  40. liger_kernel/transformers/model/phi3.py +18 -40
  41. liger_kernel/transformers/model/qwen2.py +18 -40
  42. liger_kernel/transformers/model/qwen2_vl.py +16 -30
  43. liger_kernel/transformers/monkey_patch.py +43 -117
  44. liger_kernel/transformers/rms_norm.py +4 -4
  45. liger_kernel/transformers/rope.py +2 -2
  46. liger_kernel/transformers/swiglu.py +2 -8
  47. liger_kernel/transformers/trainer/__init__.py +1 -3
  48. liger_kernel/transformers/trainer/orpo_trainer.py +13 -16
  49. liger_kernel/triton/__init__.py +1 -3
  50. liger_kernel/triton/monkey_patch.py +1 -3
  51. {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/METADATA +1 -1
  52. liger_kernel_nightly-0.5.2.dev20241223042135.dist-info/RECORD +66 -0
  53. liger_kernel_nightly-0.5.2.dev20241223032015.dist-info/RECORD +0 -66
  54. {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/LICENSE +0 -0
  55. {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/NOTICE +0 -0
  56. {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/WHEEL +0 -0
  57. {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/top_level.txt +0 -0

liger_kernel/transformers/functional.py
@@ -1,9 +1,7 @@
 from typing import Optional
 
 from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
-from liger_kernel.ops.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyFunction,
-)
+from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
 from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
 from liger_kernel.ops.geglu import LigerGELUMulFunction
 from liger_kernel.ops.group_norm import LigerGroupNormFunction
@@ -159,9 +157,7 @@ def liger_qwen2vl_mrope(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
     return LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section, unsqueeze_dim)
 
 
-def liger_rms_norm(
-    X, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True
-):
+def liger_rms_norm(X, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True):
     return LigerRMSNormFunction.apply(X, W, eps, offset, casting_mode, in_place)
 
 

liger_kernel/transformers/fused_linear_cross_entropy.py
@@ -2,9 +2,7 @@ from typing import Optional
 
 import torch
 
-from liger_kernel.ops.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyFunction,
-)
+from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
 
 
 class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
@@ -25,9 +23,7 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
             "sum",
             "none",
         }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
-        assert (
-            softcap is None or softcap > 0
-        ), f"softcap must greater than 0.0 or None. Got: {softcap}"
+        assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}"
         self.ignore_index = ignore_index
         self.lse_square_scale = lse_square_scale
         self.label_smoothing = label_smoothing

liger_kernel/transformers/geglu.py
@@ -19,7 +19,4 @@ class LigerGEGLUMLP(nn.Module):
         # So we can safely assume we use tanh approximation form all the time
 
     def forward(self, x):
-
-        return self.down_proj(
-            LigerGELUMulFunction.apply(self.gate_proj(x), self.up_proj(x))
-        )
+        return self.down_proj(LigerGELUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
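
The one-liner forward above computes down_proj(GELU(gate_proj(x)) * up_proj(x)) with the tanh-approximated GELU fused into one kernel. A hedged sketch of the underlying op, with assumed shapes:

# Hedged sketch of the fused op used in LigerGEGLUMLP.forward above.
import torch
from liger_kernel.ops.geglu import LigerGELUMulFunction

gate = torch.randn(2, 128, 11008, device="cuda", dtype=torch.bfloat16)  # gate_proj(x), assumed shape
up = torch.randn_like(gate)                                             # up_proj(x)
fused = LigerGELUMulFunction.apply(gate, up)  # tanh-approx GELU(gate) * up, same shape as inputs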

liger_kernel/transformers/group_norm.py
@@ -27,19 +27,13 @@ class LigerGroupNorm(nn.Module):
         self.num_channels = num_channels
         self.num_groups = num_groups
         self.eps = eps
-        self.weight = nn.Parameter(
-            torch.ones(num_channels) if init_fn == "ones" else torch.zeros(num_channels)
-        )
-        self.bias = nn.Parameter(
-            torch.randn(num_channels) if bias else torch.zeros(num_channels)
-        )
+        self.weight = nn.Parameter(torch.ones(num_channels) if init_fn == "ones" else torch.zeros(num_channels))
+        self.bias = nn.Parameter(torch.randn(num_channels) if bias else torch.zeros(num_channels))
         self.variance_epsilon = eps
 
     def forward(self, hidden_states):
         # hidden_states: (batch_size, num_channels, *)
-        assert (
-            hidden_states.dim() >= 3
-        ), f"Input must have atleast 3 dimensions, got {hidden_states.dim()}"
+        assert hidden_states.dim() >= 3, f"Input must have atleast 3 dimensions, got {hidden_states.dim()}"
         assert (
             hidden_states.size(1) == self.num_channels
         ), f"Input tensor must have {self.num_channels} channels, got {hidden_states.size(1)}"

liger_kernel/transformers/jsd.py
@@ -67,6 +67,4 @@ class LigerJSD(torch.nn.Module):
         log_p: torch.Tensor,
         shift_labels: Optional[torch.LongTensor] = None,
     ):
-        return LigerJSDFunction.apply(
-            log_q, log_p, shift_labels, self.beta, self.ignore_index
-        )
+        return LigerJSDFunction.apply(log_q, log_p, shift_labels, self.beta, self.ignore_index)
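
For reference, a hedged sketch of calling LigerJSD with the argument order preserved by the one-line apply above; both inputs are log-probabilities, and the constructor keyword and shapes are assumptions.

# Hedged sketch: log_q is the student distribution, log_p the teacher, per the
# parameter names above; shift_labels stays at its default of None.
import torch
from liger_kernel.transformers.jsd import LigerJSD

jsd = LigerJSD(beta=0.5)  # beta/ignore_index mirror the attributes used above (assumed kwargs)
log_q = torch.randn(8, 32000, device="cuda").log_softmax(dim=-1)
log_p = torch.randn(8, 32000, device="cuda").log_softmax(dim=-1)
loss = jsd(log_q, log_p)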

liger_kernel/transformers/kl_div.py
@@ -9,6 +9,4 @@ class LigerKLDIVLoss(nn.KLDivLoss):
         self.eps = eps
 
     def forward(self, y_pred, y_true):
-        return LigerKLDivLossFunction.apply(
-            y_pred, y_true, self.reduction, self.log_target, self.eps
-        )
+        return LigerKLDivLossFunction.apply(y_pred, y_true, self.reduction, self.log_target, self.eps)
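
A hedged sketch of LigerKLDIVLoss, which keeps torch.nn.KLDivLoss's convention of log-probability predictions; reduction and log_target come from the parent class, and the shapes are assumptions.

# Hedged sketch: y_pred are log-probabilities, y_true are probabilities
# (log_target defaults to False), matching the apply call above.
import torch
from liger_kernel.transformers.kl_div import LigerKLDIVLoss

kl = LigerKLDIVLoss()
y_pred = torch.randn(4, 32000, device="cuda").log_softmax(dim=-1)
y_true = torch.randn(4, 32000, device="cuda").softmax(dim=-1)
loss = kl(y_pred, y_true)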

liger_kernel/transformers/layer_norm.py
@@ -13,18 +13,12 @@ class LigerLayerNorm(nn.Module):
         ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
         self.hidden_size = hidden_size
         self.eps = eps
-        self.weight = nn.Parameter(
-            torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)
-        )
-        self.bias = nn.Parameter(
-            torch.randn(hidden_size) if bias else torch.zeros(hidden_size)
-        )
+        self.weight = nn.Parameter(torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size))
+        self.bias = nn.Parameter(torch.randn(hidden_size) if bias else torch.zeros(hidden_size))
         self.variance_epsilon = eps
 
     def forward(self, hidden_states):
-        return LigerLayerNormFunction.apply(
-            hidden_states, self.weight, self.bias, self.variance_epsilon
-        )
+        return LigerLayerNormFunction.apply(hidden_states, self.weight, self.bias, self.variance_epsilon)
 
     def extra_repr(self):
         return f"{self.hidden_size}, eps={self.eps}"

liger_kernel/transformers/model/gemma.py
@@ -1,27 +1,23 @@
-from typing import List, Optional, Tuple, Union
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
+
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.gemma.modeling_gemma import (
-    _CONFIG_FOR_DOC,
-    GEMMA_INPUTS_DOCSTRING,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
+from transformers.models.gemma.modeling_gemma import _CONFIG_FOR_DOC
+from transformers.models.gemma.modeling_gemma import GEMMA_INPUTS_DOCSTRING
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
 
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 
 
 @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -64,19 +60,11 @@ def lce_forward_deprecated(
         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
         "What is your favorite condiment?"
         ```"""
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -139,9 +127,7 @@ def lce_forward_deprecated(
 
 
 @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -188,19 +174,11 @@ def lce_forward(
         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
         "What is your favorite condiment?"
         ```"""
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(

liger_kernel/transformers/model/gemma2.py
@@ -1,22 +1,20 @@
 import logging
-from typing import Optional, Tuple, Union
+
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
+
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import HybridCache
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.gemma2.modeling_gemma2 import (
-    _CONFIG_FOR_DOC,
-    GEMMA2_INPUTS_DOCSTRING,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
-
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from transformers.models.gemma2.modeling_gemma2 import _CONFIG_FOR_DOC
+from transformers.models.gemma2.modeling_gemma2 import GEMMA2_INPUTS_DOCSTRING
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
+
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 
 logger = logging.getLogger(__name__)
 
@@ -63,19 +61,11 @@ def lce_forward_deprecated(
             "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
             f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
         )
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
         input_ids=input_ids,
@@ -104,9 +94,7 @@ def lce_forward_deprecated(
         shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
         shift_labels = shift_labels.view(-1)
 
-        lce = LigerFusedLinearCrossEntropyLoss(
-            softcap=self.config.final_logit_softcapping
-        )
+        lce = LigerFusedLinearCrossEntropyLoss(softcap=self.config.final_logit_softcapping)
         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
 
     else:
@@ -146,9 +134,7 @@ def lce_forward_deprecated(
 
 
 @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -201,19 +187,11 @@ def lce_forward(
             "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
            f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
         )
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
         input_ids=input_ids,

liger_kernel/transformers/model/llama.py
@@ -1,30 +1,27 @@
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
 import torch.nn.functional as F
+
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.llama.modeling_llama import (
-    _CONFIG_FOR_DOC,
-    LLAMA_INPUTS_DOCSTRING,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
-
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from transformers.models.llama.modeling_llama import _CONFIG_FOR_DOC
+from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
+
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 
 if TYPE_CHECKING:
     from transformers.cache_utils import Cache
 
 
 @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -67,19 +64,11 @@ def lce_forward_deprecated(
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -113,13 +102,8 @@ def lce_forward_deprecated(
 
     else:
         if self.config.pretraining_tp > 1:
-            lm_head_slices = self.lm_head.weight.split(
-                self.vocab_size // self.config.pretraining_tp, dim=0
-            )
-            logits = [
-                F.linear(hidden_states, lm_head_slices[i])
-                for i in range(self.config.pretraining_tp)
-            ]
+            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
+            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
             logits = torch.cat(logits, dim=-1)
         else:
             logits = self.lm_head(hidden_states)
@@ -151,9 +135,7 @@ def lce_forward_deprecated(
 
 
 @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -201,19 +183,11 @@ def lce_forward(
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
 
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
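
The pretraining_tp branch reflowed in the llama.py hunks splits the lm_head weight by rows, projects each slice, and concatenates the partial logits. A self-contained sketch with assumed sizes:

# Hedged sketch of the tensor-parallel logits path compacted above; sizes are assumed.
import torch
import torch.nn.functional as F

pretraining_tp, vocab_size, hidden_size = 2, 32000, 4096
lm_head_weight = torch.randn(vocab_size, hidden_size)
hidden_states = torch.randn(4, hidden_size)

lm_head_slices = lm_head_weight.split(vocab_size // pretraining_tp, dim=0)
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(pretraining_tp)]
logits = torch.cat(logits, dim=-1)  # (4, 32000): full-vocab logits reassembled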

liger_kernel/transformers/model/mistral.py
@@ -1,27 +1,23 @@
-from typing import List, Optional, Tuple, Union
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
+
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.mistral.modeling_mistral import (
-    _CONFIG_FOR_DOC,
-    MISTRAL_INPUTS_DOCSTRING,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
+from transformers.models.mistral.modeling_mistral import _CONFIG_FOR_DOC
+from transformers.models.mistral.modeling_mistral import MISTRAL_INPUTS_DOCSTRING
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
 
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 
 
 @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -65,19 +61,11 @@ def lce_forward(
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
 
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(

liger_kernel/transformers/model/mixtral.py
@@ -1,27 +1,23 @@
-from typing import List, Optional, Tuple, Union
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
+
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import MoeCausalLMOutputWithPast
-from transformers.models.mixtral.modeling_mixtral import (
-    _CONFIG_FOR_DOC,
-    MIXTRAL_INPUTS_DOCSTRING,
-    load_balancing_loss_func,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
+from transformers.models.mixtral.modeling_mixtral import _CONFIG_FOR_DOC
+from transformers.models.mixtral.modeling_mixtral import MIXTRAL_INPUTS_DOCSTRING
+from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
 
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 
 
 @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -66,25 +62,15 @@ def lce_forward_deprecated(
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
 
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_router_logits = (
-        output_router_logits
-        if output_router_logits is not None
-        else self.config.output_router_logits
+        output_router_logits if output_router_logits is not None else self.config.output_router_logits
     )
 
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -138,9 +124,7 @@ def lce_forward_deprecated(
             attention_mask,
         )
         if labels is not None:
-            loss += self.router_aux_loss_coef * aux_loss.to(
-                loss.device
-            )  # make sure to reside in the same device
+            loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
 
     if not return_dict:
         output = (logits,) + outputs[1:]
@@ -160,9 +144,7 @@ def lce_forward_deprecated(
 
 
 @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 # Ignore copy
 def lce_forward(
     self,
@@ -212,25 +194,15 @@ def lce_forward(
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
 
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_router_logits = (
-        output_router_logits
-        if output_router_logits is not None
-        else self.config.output_router_logits
+        output_router_logits if output_router_logits is not None else self.config.output_router_logits
     )
 
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -288,9 +260,7 @@ def lce_forward(
             attention_mask,
         )
         if labels is not None:
-            loss += self.router_aux_loss_coef * aux_loss.to(
-                loss.device
-            )  # make sure to reside in the same device
+            loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
 
     if not return_dict:
         output = (logits,) + outputs[1:]
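
The mixtral.py hunks above only reflow the auxiliary-loss line; the behavior is unchanged: scale the router load-balancing loss by router_aux_loss_coef and move it onto the main loss's device before adding. A self-contained sketch with assumed values:

# Hedged sketch of the reflowed line: values are illustrative, not from the diff.
import torch

loss = torch.tensor(2.31)       # main LM loss
aux_loss = torch.tensor(0.07)   # router load-balancing loss (possibly on another device)
router_aux_loss_coef = 0.02     # assumed config value
loss = loss + router_aux_loss_coef * aux_loss.to(loss.device)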