liger-kernel-nightly 0.3.1.dev20241101044713__py3-none-any.whl → 0.3.1.dev20241102065152__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of liger-kernel-nightly might be problematic; see the package's registry page for more details.

Files changed (12):
  1. liger_kernel/transformers/model/llama.py +22 -19
  2. {liger_kernel_nightly-0.3.1.dev20241101044713.dist-info → liger_kernel_nightly-0.3.1.dev20241102065152.dist-info}/METADATA +2 -2
  3. {liger_kernel_nightly-0.3.1.dev20241101044713.dist-info → liger_kernel_nightly-0.3.1.dev20241102065152.dist-info}/RECORD +12 -12
  4. {liger_kernel_nightly-0.3.1.dev20241101044713.dist-info → liger_kernel_nightly-0.3.1.dev20241102065152.dist-info}/LICENSE +0 -0
  5. {liger_kernel_nightly-0.3.1.dev20241101044713.dist-info → liger_kernel_nightly-0.3.1.dev20241102065152.dist-info}/LICENSE-Apache-2.0 +0 -0
  6. {liger_kernel_nightly-0.3.1.dev20241101044713.dist-info → liger_kernel_nightly-0.3.1.dev20241102065152.dist-info}/LICENSE-MIT-AutoAWQ +0 -0
  7. {liger_kernel_nightly-0.3.1.dev20241101044713.dist-info → liger_kernel_nightly-0.3.1.dev20241102065152.dist-info}/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  8. {liger_kernel_nightly-0.3.1.dev20241101044713.dist-info → liger_kernel_nightly-0.3.1.dev20241102065152.dist-info}/LICENSE-MIT-llmc +0 -0
  9. {liger_kernel_nightly-0.3.1.dev20241101044713.dist-info → liger_kernel_nightly-0.3.1.dev20241102065152.dist-info}/LICENSE-MIT-triton +0 -0
  10. {liger_kernel_nightly-0.3.1.dev20241101044713.dist-info → liger_kernel_nightly-0.3.1.dev20241102065152.dist-info}/NOTICE +0 -0
  11. {liger_kernel_nightly-0.3.1.dev20241101044713.dist-info → liger_kernel_nightly-0.3.1.dev20241102065152.dist-info}/WHEEL +0 -0
  12. {liger_kernel_nightly-0.3.1.dev20241101044713.dist-info → liger_kernel_nightly-0.3.1.dev20241102065152.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- from typing import List, Optional, Tuple, Union
1
+ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
2
2
 
3
3
  import torch
4
4
  import torch.nn.functional as F
@@ -18,6 +18,10 @@ from liger_kernel.transformers.fused_linear_cross_entropy import (
18
18
  )
19
19
 
20
20
 
21
+ if TYPE_CHECKING:
22
+ from transformers.cache_utils import Cache
23
+
24
+
21
25
  @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
22
26
  @replace_return_docstrings(
23
27
  output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
@@ -27,7 +31,7 @@ def lce_forward_deprecated(
27
31
  input_ids: torch.LongTensor = None,
28
32
  attention_mask: Optional[torch.Tensor] = None,
29
33
  position_ids: Optional[torch.LongTensor] = None,
30
- past_key_values: Optional[List[torch.FloatTensor]] = None,
34
+ past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None,
31
35
  inputs_embeds: Optional[torch.FloatTensor] = None,
32
36
  labels: Optional[torch.LongTensor] = None,
33
37
  use_cache: Optional[bool] = None,
@@ -153,19 +157,19 @@ def lce_forward_deprecated(
153
157
  )
154
158
  def lce_forward(
155
159
  self,
156
- input_ids=None,
157
- attention_mask=None,
158
- position_ids=None,
159
- past_key_values=None,
160
- inputs_embeds=None,
161
- labels=None,
162
- use_cache=None,
163
- output_attentions=None,
164
- output_hidden_states=None,
165
- return_dict=None,
166
- cache_position=None,
167
- num_logits_to_keep=0,
168
- **kwargs,
160
+ input_ids: torch.LongTensor = None,
161
+ attention_mask: Optional[torch.Tensor] = None,
162
+ position_ids: Optional[torch.LongTensor] = None,
163
+ past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None,
164
+ inputs_embeds: Optional[torch.FloatTensor] = None,
165
+ labels: Optional[torch.LongTensor] = None,
166
+ use_cache: Optional[bool] = None,
167
+ output_attentions: Optional[bool] = None,
168
+ output_hidden_states: Optional[bool] = None,
169
+ return_dict: Optional[bool] = None,
170
+ cache_position: Optional[torch.LongTensor] = None,
171
+ num_logits_to_keep: int = 0,
172
+ **loss_kwargs,
169
173
  ) -> Union[Tuple, CausalLMOutputWithPast]:
170
174
  r"""
171
175
  Args:
@@ -224,7 +228,6 @@ def lce_forward(
224
228
  output_hidden_states=output_hidden_states,
225
229
  return_dict=return_dict,
226
230
  cache_position=cache_position,
227
- **kwargs,
228
231
  )
229
232
 
230
233
  hidden_states = outputs[0]
@@ -245,12 +248,12 @@ def lce_forward(
245
248
  shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
246
249
  shift_labels = shift_labels.view(-1)
247
250
 
248
- reduction = "sum" if "num_items_in_batch" in kwargs else "mean"
251
+ reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
249
252
  lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
250
253
 
251
254
  loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
252
255
  if reduction == "sum":
253
- loss /= kwargs["num_items_in_batch"]
256
+ loss /= loss_kwargs["num_items_in_batch"]
254
257
 
255
258
  else: # if in inference mode materialize logits
256
259
  logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
@@ -259,7 +262,7 @@ def lce_forward(
259
262
  logits=logits,
260
263
  labels=labels,
261
264
  vocab_size=self.config.vocab_size,
262
- **kwargs,
265
+ **loss_kwargs,
263
266
  )
264
267
 
265
268
  if not return_dict:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.3.1.dev20241101044713
3
+ Version: 0.3.1.dev20241102065152
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -36,7 +36,7 @@ License-File: LICENSE-MIT-llmc
36
36
  License-File: LICENSE-MIT-triton
37
37
  License-File: NOTICE
38
38
  Requires-Dist: torch>=2.1.2
39
- Requires-Dist: triton>=2.3.0
39
+ Requires-Dist: triton>=2.3.1
40
40
  Provides-Extra: dev
41
41
  Requires-Dist: transformers>=4.44.2; extra == "dev"
42
42
  Requires-Dist: matplotlib>=3.7.2; extra == "dev"
@@ -31,7 +31,7 @@ liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx
31
31
  liger_kernel/transformers/experimental/embedding.py,sha256=HpckiAMKM8-SRxKDcGTqortVxnjhwpZsfsp9lfjqfeM,895
32
32
  liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
33
33
  liger_kernel/transformers/model/gemma.py,sha256=EcdkGbSj_qroTDFl0Sc_HLyDyY0xcDhwrgkM_wkXnw8,4987
34
- liger_kernel/transformers/model/llama.py,sha256=XZ5rBck_2uVHHKQ5bsbVPio_Pd545BjTwTpAA0uLZAA,10028
34
+ liger_kernel/transformers/model/llama.py,sha256=KSCUkUnHrhL0jI4NRtJrPC0tbp-oFBCJEiqgga0HuTU,10427
35
35
  liger_kernel/transformers/model/mistral.py,sha256=_MQJrDntlxBO5cJwgTjr2rk2nNd5FAXVnzcTg_PEekQ,5079
36
36
  liger_kernel/transformers/model/mixtral.py,sha256=51FghRY8aGBWat7KSgTeFDqdStDiXY3dEJepByNhEOE,5847
37
37
  liger_kernel/transformers/model/mllama.py,sha256=S00P0pJrGHOWBx170TPYZbQ0djv0__m8Dqv1FvKZUyE,5926
@@ -40,14 +40,14 @@ liger_kernel/transformers/model/qwen2.py,sha256=3inWFXGHYT7wA10OR6bq3mDUBrr10AS5
40
40
  liger_kernel/transformers/model/qwen2_vl.py,sha256=ymsm9aQpSUiSU12GY8FO608p9dSHOz4TCnNI1htX5bk,6975
41
41
  liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
42
42
  liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
43
- liger_kernel_nightly-0.3.1.dev20241101044713.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
44
- liger_kernel_nightly-0.3.1.dev20241101044713.dist-info/LICENSE-Apache-2.0,sha256=NRaCIsL9eblGS35gk4WKTC0usNYnR_mgRHJTKqz2_UE,11348
45
- liger_kernel_nightly-0.3.1.dev20241101044713.dist-info/LICENSE-MIT-AutoAWQ,sha256=pfiOyInrAPY3xQbvV1i-gOqNZK7QEyIepT1IbqOYYYo,1067
46
- liger_kernel_nightly-0.3.1.dev20241101044713.dist-info/LICENSE-MIT-Efficient-Cross-Entropy,sha256=PaC9HqyFYTy-ClS0H8Zfa2motJuTppjECXmjHwJcaOk,1063
47
- liger_kernel_nightly-0.3.1.dev20241101044713.dist-info/LICENSE-MIT-llmc,sha256=kyFLt_XUcXS88CuxQt5-PjOcLjpJP2m-T4gtqZf3GLc,1071
48
- liger_kernel_nightly-0.3.1.dev20241101044713.dist-info/LICENSE-MIT-triton,sha256=wL6W8IwsKiyHtzXubg8TCXhRZuo8S83EPdqXffYtqWg,1131
49
- liger_kernel_nightly-0.3.1.dev20241101044713.dist-info/METADATA,sha256=wLLMrLqw2pA47vWt1bqdc6lEC9UptNFvpB5_KnqBwj4,27717
50
- liger_kernel_nightly-0.3.1.dev20241101044713.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
51
- liger_kernel_nightly-0.3.1.dev20241101044713.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
52
- liger_kernel_nightly-0.3.1.dev20241101044713.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
53
- liger_kernel_nightly-0.3.1.dev20241101044713.dist-info/RECORD,,
43
+ liger_kernel_nightly-0.3.1.dev20241102065152.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
44
+ liger_kernel_nightly-0.3.1.dev20241102065152.dist-info/LICENSE-Apache-2.0,sha256=NRaCIsL9eblGS35gk4WKTC0usNYnR_mgRHJTKqz2_UE,11348
45
+ liger_kernel_nightly-0.3.1.dev20241102065152.dist-info/LICENSE-MIT-AutoAWQ,sha256=pfiOyInrAPY3xQbvV1i-gOqNZK7QEyIepT1IbqOYYYo,1067
46
+ liger_kernel_nightly-0.3.1.dev20241102065152.dist-info/LICENSE-MIT-Efficient-Cross-Entropy,sha256=PaC9HqyFYTy-ClS0H8Zfa2motJuTppjECXmjHwJcaOk,1063
47
+ liger_kernel_nightly-0.3.1.dev20241102065152.dist-info/LICENSE-MIT-llmc,sha256=kyFLt_XUcXS88CuxQt5-PjOcLjpJP2m-T4gtqZf3GLc,1071
48
+ liger_kernel_nightly-0.3.1.dev20241102065152.dist-info/LICENSE-MIT-triton,sha256=wL6W8IwsKiyHtzXubg8TCXhRZuo8S83EPdqXffYtqWg,1131
49
+ liger_kernel_nightly-0.3.1.dev20241102065152.dist-info/METADATA,sha256=8To-7_aoXsEfzk2d1hYaVSCwdirZDFJmavscRZLumXs,27717
50
+ liger_kernel_nightly-0.3.1.dev20241102065152.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
51
+ liger_kernel_nightly-0.3.1.dev20241102065152.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
52
+ liger_kernel_nightly-0.3.1.dev20241102065152.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
53
+ liger_kernel_nightly-0.3.1.dev20241102065152.dist-info/RECORD,,