liger-kernel-nightly 0.5.9.dev20250515163614__py3-none-any.whl → 0.5.9.dev20250516193902__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (23)
  1. liger_kernel/transformers/model/gemma.py +0 -8
  2. liger_kernel/transformers/model/gemma2.py +0 -6
  3. liger_kernel/transformers/model/gemma3.py +0 -8
  4. liger_kernel/transformers/model/glm4.py +0 -6
  5. liger_kernel/transformers/model/llama.py +0 -8
  6. liger_kernel/transformers/model/llava.py +0 -8
  7. liger_kernel/transformers/model/mistral.py +0 -6
  8. liger_kernel/transformers/model/mixtral.py +0 -8
  9. liger_kernel/transformers/model/mllama.py +0 -7
  10. liger_kernel/transformers/model/olmo2.py +0 -6
  11. liger_kernel/transformers/model/paligemma.py +0 -8
  12. liger_kernel/transformers/model/phi3.py +0 -8
  13. liger_kernel/transformers/model/qwen2.py +0 -8
  14. liger_kernel/transformers/model/qwen2_5_vl.py +0 -6
  15. liger_kernel/transformers/model/qwen2_vl.py +0 -6
  16. liger_kernel/transformers/model/qwen3.py +0 -6
  17. liger_kernel/transformers/model/qwen3_moe.py +0 -6
  18. {liger_kernel_nightly-0.5.9.dev20250515163614.dist-info → liger_kernel_nightly-0.5.9.dev20250516193902.dist-info}/METADATA +1 -1
  19. {liger_kernel_nightly-0.5.9.dev20250515163614.dist-info → liger_kernel_nightly-0.5.9.dev20250516193902.dist-info}/RECORD +23 -23
  20. {liger_kernel_nightly-0.5.9.dev20250515163614.dist-info → liger_kernel_nightly-0.5.9.dev20250516193902.dist-info}/LICENSE +0 -0
  21. {liger_kernel_nightly-0.5.9.dev20250515163614.dist-info → liger_kernel_nightly-0.5.9.dev20250516193902.dist-info}/NOTICE +0 -0
  22. {liger_kernel_nightly-0.5.9.dev20250515163614.dist-info → liger_kernel_nightly-0.5.9.dev20250516193902.dist-info}/WHEEL +0 -0
  23. {liger_kernel_nightly-0.5.9.dev20250515163614.dist-info → liger_kernel_nightly-0.5.9.dev20250516193902.dist-info}/top_level.txt +0 -0
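Every code change in this release is the same cleanup, applied across the seventeen model files: each Liger forward implementation drops the `add_start_docstrings_to_model_forward` and `replace_return_docstrings` decorators together with the `_CONFIG_FOR_DOC` and `*_INPUTS_DOCSTRING` imports they depend on, presumably because current transformers releases no longer ship those docstring helpers. The forward bodies themselves are untouched. For orientation, here is a minimal Python sketch of what the removed decorators did; this is a simplified reconstruction for illustration, not the transformers implementation:

    def add_start_docstrings_to_model_forward(docstring):
        # Prepend the model's shared "inputs" docstring to a forward method.
        def decorator(fn):
            fn.__doc__ = docstring + (fn.__doc__ or "")
            return fn
        return decorator

    def replace_return_docstrings(output_type=None, config_class=None):
        # Fill in a "Returns:" section from the output class and the
        # model's config class name.
        def decorator(fn):
            name = getattr(output_type, "__name__", output_type)
            fn.__doc__ = (fn.__doc__ or "") + f"\n\nReturns:\n    `{name}` (see `{config_class}`)\n"
            return fn
        return decorator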
liger_kernel/transformers/model/gemma.py
@@ -8,18 +8,12 @@ import torch
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.gemma.modeling_gemma import _CONFIG_FOR_DOC
-from transformers.models.gemma.modeling_gemma import GEMMA_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg

 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


-@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -129,8 +123,6 @@ def lce_forward_deprecated(


 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
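Note that `lce_forward` keeps its `@deprecate_kwarg` decorator in every file: the old `num_logits_to_keep` keyword is still accepted and remapped to `logits_to_keep`, with a deprecation warning, until transformers 4.50. A minimal sketch of that remapping pattern (simplified; the real helper lives in `transformers.utils.deprecation`):

    import functools
    import warnings

    def deprecate_kwarg(old_name, version, new_name=None):
        # Remap a renamed keyword argument and warn the caller.
        def decorator(fn):
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                if old_name in kwargs:
                    warnings.warn(
                        f"`{old_name}` is deprecated, removal planned for {version}; "
                        f"use `{new_name}` instead.",
                        FutureWarning,
                    )
                    if new_name is not None:
                        kwargs[new_name] = kwargs.pop(old_name)
                return fn(*args, **kwargs)
            return wrapper
        return decorator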
liger_kernel/transformers/model/gemma2.py
@@ -9,10 +9,6 @@ import torch
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import HybridCache
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.gemma2.modeling_gemma2 import _CONFIG_FOR_DOC
-from transformers.models.gemma2.modeling_gemma2 import GEMMA2_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg

 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
@@ -136,8 +132,6 @@ def lce_forward_deprecated(


 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
liger_kernel/transformers/model/gemma3.py
@@ -9,13 +9,9 @@ import torch.nn as nn
 from transformers.cache_utils import Cache
 from transformers.cache_utils import HybridCache
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.gemma3.modeling_gemma3 import _CONFIG_FOR_DOC
-from transformers.models.gemma3.modeling_gemma3 import GEMMA3_INPUTS_DOCSTRING
 from transformers.models.gemma3.modeling_gemma3 import Gemma3CausalLMOutputWithPast
-from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import is_torchdynamo_compiling
 from transformers.utils import logging
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg

 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
@@ -25,8 +21,6 @@ logger = logging.get_logger(__name__)


 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def causal_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -141,8 +135,6 @@ def causal_forward(


 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def multimodal_forward(
     self,
     input_ids: torch.LongTensor = None,
liger_kernel/transformers/model/glm4.py
@@ -6,18 +6,12 @@ from typing import Union
 import torch

 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.glm4.modeling_glm4 import _CONFIG_FOR_DOC
-from transformers.models.glm4.modeling_glm4 import GLM4_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg

 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(GLM4_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
liger_kernel/transformers/model/llama.py
@@ -9,10 +9,6 @@ import torch.nn.functional as F

 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.llama.modeling_llama import _CONFIG_FOR_DOC
-from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg

 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
@@ -22,8 +18,6 @@ if TYPE_CHECKING:
     from transformers.cache_utils import Cache


-@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -137,8 +131,6 @@ def lce_forward_deprecated(


 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
liger_kernel/transformers/model/llava.py
@@ -5,19 +5,13 @@ from typing import Union

 import torch

-from transformers.models.llava.modeling_llava import _CONFIG_FOR_DOC
-from transformers.models.llava.modeling_llava import LLAVA_INPUTS_DOCSTRING
 from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast
-from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import is_torchdynamo_compiling
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg

 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss


-@add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -210,9 +204,7 @@ def lce_forward_deprecated(
     )


-@add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
liger_kernel/transformers/model/mistral.py
@@ -7,18 +7,12 @@ import torch

 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.mistral.modeling_mistral import _CONFIG_FOR_DOC
-from transformers.models.mistral.modeling_mistral import MISTRAL_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg

 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
liger_kernel/transformers/model/mixtral.py
@@ -7,19 +7,13 @@ import torch

 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import MoeCausalLMOutputWithPast
-from transformers.models.mixtral.modeling_mixtral import _CONFIG_FOR_DOC
-from transformers.models.mixtral.modeling_mixtral import MIXTRAL_INPUTS_DOCSTRING
 from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg

 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


-@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -146,8 +140,6 @@ def lce_forward_deprecated(


 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 # Ignore copy
 def lce_forward(
     self,
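mixtral.py keeps its `load_balancing_loss_func` import (as does qwen3_moe.py below): the MoE forwards still add the router load-balancing term on top of the language-modeling loss. Schematically, following the standard Mixtral pattern (a sketch; `router_aux_loss_coef` names the config field, and `combine_moe_loss` is an illustrative helper, not code from this package):

    import torch

    def combine_moe_loss(lm_loss, aux_loss, router_aux_loss_coef):
        # Mixtral-style total loss: LM loss plus a scaled router
        # load-balancing auxiliary term.
        if lm_loss is None or aux_loss is None:
            return lm_loss
        return lm_loss + router_aux_loss_coef * aux_loss.to(lm_loss.device)

    print(combine_moe_loss(torch.tensor(2.31), torch.tensor(0.04), 0.02))  # tensor(2.3108)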
liger_kernel/transformers/model/mllama.py
@@ -8,17 +8,12 @@ import torch
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.mllama.modeling_mllama import MLLAMA_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg

 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


-@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig")
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -135,8 +130,6 @@ def lce_forward_deprecated(


 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig")
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
liger_kernel/transformers/model/olmo2.py
@@ -6,18 +6,12 @@ from typing import Union
 import torch

 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.olmo2.modeling_olmo2 import _CONFIG_FOR_DOC
-from transformers.models.olmo2.modeling_olmo2 import OLMO2_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg

 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(OLMO2_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
liger_kernel/transformers/model/paligemma.py
@@ -7,13 +7,9 @@ import torch

 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache
-from transformers.models.paligemma.modeling_paligemma import _CONFIG_FOR_DOC
-from transformers.models.paligemma.modeling_paligemma import PALIGEMMA_INPUTS_DOCSTRING
 from transformers.models.paligemma.modeling_paligemma import PaliGemmaCausalLMOutputWithPast
-from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import is_torchdynamo_compiling
 from transformers.utils import logging
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg

 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
@@ -21,8 +17,6 @@ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinea
 logger = logging.get_logger(__name__)


-@add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -206,8 +200,6 @@ def lce_forward_deprecated(


 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
liger_kernel/transformers/model/phi3.py
@@ -7,18 +7,12 @@ import torch

 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.phi3.modeling_phi3 import _CONFIG_FOR_DOC
-from transformers.models.phi3.modeling_phi3 import PHI3_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg

 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


-@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -128,8 +122,6 @@ def lce_forward_deprecated(


 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
liger_kernel/transformers/model/qwen2.py
@@ -7,18 +7,12 @@ import torch

 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.qwen2.modeling_qwen2 import _CONFIG_FOR_DOC
-from transformers.models.qwen2.modeling_qwen2 import QWEN2_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg

 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


-@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -127,8 +121,6 @@ def lce_forward_deprecated(


 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
liger_kernel/transformers/model/qwen2_5_vl.py
@@ -6,17 +6,11 @@ from typing import Union
 import torch

 from torch.nn import CrossEntropyLoss
-from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import _CONFIG_FOR_DOC
-from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import QWEN2_5_VL_INPUTS_DOCSTRING
 from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLCausalLMOutputWithPast
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings

 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


-@add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
liger_kernel/transformers/model/qwen2_vl.py
@@ -8,17 +8,11 @@ import torch
 from packaging import version
 from torch.nn import CrossEntropyLoss
 from transformers import __version__ as transformers_version
-from transformers.models.qwen2_vl.modeling_qwen2_vl import _CONFIG_FOR_DOC
-from transformers.models.qwen2_vl.modeling_qwen2_vl import QWEN2_VL_INPUTS_DOCSTRING
 from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings

 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


-@add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
liger_kernel/transformers/model/qwen3.py
@@ -5,16 +5,10 @@ from typing import Union
 import torch

 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.qwen3.modeling_qwen3 import _CONFIG_FOR_DOC
-from transformers.models.qwen3.modeling_qwen3 import QWEN3_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings

 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


-@add_start_docstrings_to_model_forward(QWEN3_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: Optional[torch.LongTensor] = None,
liger_kernel/transformers/model/qwen3_moe.py
@@ -7,16 +7,10 @@ import torch
 from transformers.modeling_outputs import MoeCausalLMOutputWithPast
 from transformers.modeling_outputs import MoeModelOutputWithPast
 from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
-from transformers.models.qwen3_moe.modeling_qwen3_moe import _CONFIG_FOR_DOC
-from transformers.models.qwen3_moe.modeling_qwen3_moe import QWEN3_MOE_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings

 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


-@add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: Optional[torch.LongTensor] = None,
{liger_kernel_nightly-0.5.9.dev20250515163614.dist-info → liger_kernel_nightly-0.5.9.dev20250516193902.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel_nightly
-Version: 0.5.9.dev20250515163614
+Version: 0.5.9.dev20250516193902
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
 Copyright 2024 LinkedIn Corporation
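The only METADATA change is the nightly timestamp in the version. PEP 440 orders `.devN` releases by N, so the new build sorts after the old one and both sort before a final 0.5.9. A quick check with the packaging library (illustrative, not part of the wheel):

    from packaging.version import Version

    old = Version("0.5.9.dev20250515163614")
    new = Version("0.5.9.dev20250516193902")
    assert new > old               # later timestamp, later dev release
    assert new < Version("0.5.9")  # dev builds precede the final release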
{liger_kernel_nightly-0.5.9.dev20250515163614.dist-info → liger_kernel_nightly-0.5.9.dev20250516193902.dist-info}/RECORD
@@ -57,31 +57,31 @@ liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx
 liger_kernel/transformers/tvd.py,sha256=XrRfyJIqN6HFxXk8MYyFVZM1OLz3mtSbRZvWfZ_JerQ,450
 liger_kernel/transformers/experimental/embedding.py,sha256=2P0QYdlFyFrG5OqTzTa1wcRgDSyjBMv5i1a7BrDPDQw,881
 liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-liger_kernel/transformers/model/gemma.py,sha256=nMUY2Iw7j6a-fOUqYBlfzIPznpKPKVa2DMBIZqCVfuI,10087
-liger_kernel/transformers/model/gemma2.py,sha256=eulrUbh1DEMpMR6Lupx69kL-FeuRDP19mVoW1gc7keY,11194
-liger_kernel/transformers/model/gemma3.py,sha256=wGSNqaLRRgIGQ_r9esyhDezm2SkAGZflopoWoWR-nYY,16226
-liger_kernel/transformers/model/glm4.py,sha256=rtyMTtzgh_ncZ7DsfNxRJoUUm7xlDMKGzNqlxXjdAJk,5452
-liger_kernel/transformers/model/llama.py,sha256=F8cvDAlf4NeKESdGEFXs8m3ue2F8i0h3aV2LricMqoM,10764
-liger_kernel/transformers/model/llava.py,sha256=b0pEagjUbu2-eS9xegjyfl1DwIXLwZcNpff55ibaMbA,17601
+liger_kernel/transformers/model/gemma.py,sha256=gi5fVeFPryoYy0_T3rzU2wm7v_xiJnLCnTkQYR86_nk,9504
+liger_kernel/transformers/model/gemma2.py,sha256=61uH9JSZM6cPDoGHr2kNUVq2O4A3XIy2Qea36XhkkPQ,10761
+liger_kernel/transformers/model/gemma3.py,sha256=e-o7rcOJAJMZDJBB-blkLz5ildWjuDneSkakqwrADBc,15630
+liger_kernel/transformers/model/glm4.py,sha256=yYbQEcSrSTMleNTpwJosMhBf4VC9-79EyC__utmOSFg,5031
+liger_kernel/transformers/model/llama.py,sha256=pkkoKip94p3hNWA11cIVvTdNqCRB8FgR039pZWLqNeA,10181
+liger_kernel/transformers/model/llava.py,sha256=RjLVnpHtOClc1jJkkPSqke7fcgWC3Jjh1rrGyvh5kb8,17008
 liger_kernel/transformers/model/loss_utils.py,sha256=WWAMdiONPaXpIvxyOim_0igLrYh0yyOok5Q9_L9xvZw,1787
-liger_kernel/transformers/model/mistral.py,sha256=1AcwJT9WOIpHkpu4Njs35ZryiGyW8ygERYmGqLz2Z4o,5752
-liger_kernel/transformers/model/mixtral.py,sha256=URMzPLU1akf1H4hHXalCyfbVGUldRx8_jqdrZfM7Y-w,11773
-liger_kernel/transformers/model/mllama.py,sha256=v_ayi6m4sC6AVKTrrLHF4W5HVaL86AYQNBqdWuTTOTw,11579
-liger_kernel/transformers/model/olmo2.py,sha256=Kb6sGPsQS970GsYmWoT0DC2DFiXQ9Yjyxr8FRnT_8tQ,5460
-liger_kernel/transformers/model/paligemma.py,sha256=GNReT6tVZt3ON6aaa9ovg8mnu1hYocSx9OhgC7b-_28,19191
-liger_kernel/transformers/model/phi3.py,sha256=TSeHK8H0mnS2esJaZI3lxmo5X3-Uwtd_TsrgvJRkm3s,10726
-liger_kernel/transformers/model/qwen2.py,sha256=bEusb6vrVbagtSUHyntpi9j0x79IrZ1NP8iA5GR5Ryw,10015
-liger_kernel/transformers/model/qwen2_5_vl.py,sha256=oACIsTpg9_GdoSvekCyXLhJkuCpQEiFOTzKj7cjgi2E,9413
-liger_kernel/transformers/model/qwen2_vl.py,sha256=F6DeQ65wPtcpeQJZ9a3SJZKkQ-e24SRLdYUgC-_jT-k,9809
-liger_kernel/transformers/model/qwen3.py,sha256=JdIeh0fvDLdGs8nk4_eHrovHCNa09VG15D4aa0X0mwI,5084
-liger_kernel/transformers/model/qwen3_moe.py,sha256=2EFIltbaQ6y8ksYDTk0NC0b2Zdbir7eW15avY4XisLQ,5917
+liger_kernel/transformers/model/mistral.py,sha256=0lt1Jq37zWjxLZF-Vuj9jUyIEnWlMuT7PB5xB42KXBs,5313
+liger_kernel/transformers/model/mixtral.py,sha256=KpxDHtj7OCrZj_KrUWByRKM3A_x9o1S26rU3XGd1Ro8,11170
+liger_kernel/transformers/model/mllama.py,sha256=eElsJpBjdLfWhAZsYcfWnp_1tAf6t8jvliszu-v7sVg,11054
+liger_kernel/transformers/model/olmo2.py,sha256=FH_BY6pTiLgcjqsO1rprl9vcL_iZgBHBszelXgVj47Y,5033
+liger_kernel/transformers/model/paligemma.py,sha256=zXVV7FkhBnuHrbMg-CTOK21B90but6NqFd0DCeEefQE,18562
+liger_kernel/transformers/model/phi3.py,sha256=jYFqWcfP9wT9WUZeOC0SWjX_ZtWzQSDHDWH40m91TGE,10150
+liger_kernel/transformers/model/qwen2.py,sha256=b0fF5HX009VRrAGu9O2pG73YDDR05x_oy7JV9dvHuak,9432
+liger_kernel/transformers/model/qwen2_5_vl.py,sha256=F3lnFpKxTyij7ToEWc0hmXXyrdSsnbEfPSNCh9tAF0Y,8946
+liger_kernel/transformers/model/qwen2_vl.py,sha256=q3AMpxFfwHjaMu9Q3jpwpMPRzrE-eLqppg_8Z0ixjaQ,9357
+liger_kernel/transformers/model/qwen3.py,sha256=u_0cCRwr1jcwMkSknbBVb9my1OepCGU718uxKhNUOVM,4657
+liger_kernel/transformers/model/qwen3_moe.py,sha256=lIWGunVtNP-d7VfRvEGY820howzecb10g6ZeWRgsfl8,5463
 liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7HHWHwku25A-GYL0WU,193
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.5.9.dev20250515163614.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
-liger_kernel_nightly-0.5.9.dev20250515163614.dist-info/METADATA,sha256=plSThqoxpxBqQ5hkqApOjytL6Scb5kWrj6xE9MhrG9k,23874
-liger_kernel_nightly-0.5.9.dev20250515163614.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
-liger_kernel_nightly-0.5.9.dev20250515163614.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
-liger_kernel_nightly-0.5.9.dev20250515163614.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
-liger_kernel_nightly-0.5.9.dev20250515163614.dist-info/RECORD,,
+liger_kernel_nightly-0.5.9.dev20250516193902.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.9.dev20250516193902.dist-info/METADATA,sha256=XhWjsVPmpt1AME9f-1gVh46eTdK0D0nnusCJ8wtSBRk,23874
+liger_kernel_nightly-0.5.9.dev20250516193902.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.9.dev20250516193902.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+liger_kernel_nightly-0.5.9.dev20250516193902.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.9.dev20250516193902.dist-info/RECORD,,
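Each RECORD row is `path,sha256=<digest>,<size>`: the digest is an unpadded urlsafe-base64 sha256 and the size is in bytes, which is why every touched module gets a new hash and a smaller length here. To reproduce an entry locally (a sketch; `record_entry` is an illustrative helper, not part of the package):

    import base64
    import hashlib
    import os

    def record_entry(path):
        # Hash the file and format it the way wheel RECORD files expect:
        # urlsafe base64 of the sha256 digest with '=' padding stripped.
        with open(path, "rb") as f:
            digest = hashlib.sha256(f.read()).digest()
        b64 = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
        return f"{path},sha256={b64},{os.path.getsize(path)}"

    print(record_entry("liger_kernel/transformers/model/gemma.py"))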