liger-kernel-nightly 0.4.0.dev20241109010649__py3-none-any.whl → 0.4.0.dev20241112204954__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of liger-kernel-nightly might be problematic. Click here for more details.

@@ -4,11 +4,13 @@ import sys
4
4
 
5
5
  def print_env_report():
6
6
  """
7
+
7
8
  Prints a report of the environment. Useful for debugging and reproducibility.
8
9
  Usage:
9
10
  ```
10
11
  python -m liger_kernel.env_report
11
12
  ```
13
+
12
14
  """
13
15
  print("Environment Report:")
14
16
  print("-------------------")
@@ -1,3 +1,5 @@
1
+ from typing import Optional
2
+
1
3
  from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
2
4
  from liger_kernel.ops.fused_linear_cross_entropy import (
3
5
  LigerFusedLinearCrossEntropyFunction,
@@ -13,7 +15,6 @@ from liger_kernel.ops.rope import LigerRopeFunction
13
15
  from liger_kernel.ops.swiglu import LigerSiLUMulFunction
14
16
 
15
17
  liger_swiglu = LigerSiLUMulFunction.apply
16
- liger_cross_entropy = LigerCrossEntropyFunction.apply
17
18
  liger_fused_linear_cross_entropy = LigerFusedLinearCrossEntropyFunction.apply
18
19
  liger_geglu = LigerGELUMulFunction.apply
19
20
  liger_rms_norm = LigerRMSNormFunction.apply
@@ -23,3 +24,33 @@ liger_kl_div = LigerKLDivLossFunction.apply
23
24
  liger_jsd = LigerJSDFunction.apply
24
25
  liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
25
26
  liger_group_norm = LigerGroupNormFunction.apply
27
+
28
+
29
+ # conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
30
+ # `weight`, `size_average` and `reduce` are placeholders: they are accepted for signature compatibility but ignored (not implemented yet)
31
+ def liger_cross_entropy(
32
+ input,
33
+ target,
34
+ weight=None,
35
+ size_average=None,
36
+ ignore_index: int = -100,
37
+ reduce=None,
38
+ reduction: str = "mean",
39
+ label_smoothing: float = 0.0,
40
+ lse_square_scale: float = 0.0,
41
+ softcap: Optional[float] = None,
42
+ return_z_loss: bool = False,
43
+ ):
44
+ loss, z_loss = LigerCrossEntropyFunction.apply(
45
+ input,
46
+ target,
47
+ ignore_index,
48
+ lse_square_scale,
49
+ label_smoothing,
50
+ reduction,
51
+ softcap,
52
+ return_z_loss,
53
+ )
54
+ if not return_z_loss:
55
+ return loss
56
+ return loss, z_loss
@@ -8,6 +8,7 @@ from packaging import version
8
8
  from transformers import PreTrainedModel
9
9
 
10
10
  from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
11
+ from liger_kernel.transformers.functional import liger_cross_entropy
11
12
  from liger_kernel.transformers.geglu import LigerGEGLUMLP
12
13
  from liger_kernel.transformers.layer_norm import LigerLayerNorm
13
14
  from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
@@ -111,8 +112,16 @@ def apply_liger_kernel_to_llama(
111
112
  modeling_llama.LlamaRMSNorm = LigerRMSNorm
112
113
  if swiglu:
113
114
  modeling_llama.LlamaMLP = LigerSwiGLUMLP
115
+
114
116
  if cross_entropy:
115
- modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
117
+ if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
118
+ from transformers.loss.loss_utils import nn
119
+
120
+ nn.functional.cross_entropy = liger_cross_entropy
121
+ else:
122
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
123
+ modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
124
+
116
125
  if fused_linear_cross_entropy:
117
126
  if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
118
127
  modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
@@ -192,7 +201,13 @@ def apply_liger_kernel_to_mllama(
192
201
  if swiglu:
193
202
  modeling_mllama.MllamaTextMLP = LigerSwiGLUMLP
194
203
  if cross_entropy:
195
- modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
204
+ if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
205
+ from transformers.loss.loss_utils import nn
206
+
207
+ nn.functional.cross_entropy = liger_cross_entropy
208
+ else:
209
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
210
+ modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
196
211
  if fused_linear_cross_entropy:
197
212
  if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
198
213
  modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
@@ -342,7 +357,14 @@ def apply_liger_kernel_to_mixtral(
342
357
  if rms_norm:
343
358
  modeling_mixtral.MixtralRMSNorm = LigerRMSNorm
344
359
  if cross_entropy:
345
- modeling_mixtral.CrossEntropyLoss = LigerCrossEntropyLoss
360
+ if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
361
+ from transformers.loss.loss_utils import nn
362
+
363
+ nn.functional.cross_entropy = liger_cross_entropy
364
+ else:
365
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
366
+ modeling_mixtral.CrossEntropyLoss = LigerCrossEntropyLoss
367
+
346
368
  if fused_linear_cross_entropy:
347
369
  if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
348
370
  modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
@@ -417,7 +439,13 @@ def apply_liger_kernel_to_gemma(
417
439
  if rms_norm:
418
440
  modeling_gemma.GemmaRMSNorm = LigerRMSNormForGemma
419
441
  if cross_entropy:
420
- modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
442
+ if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
443
+ from transformers.loss.loss_utils import nn
444
+
445
+ nn.functional.cross_entropy = liger_cross_entropy
446
+ else:
447
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
448
+ modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
421
449
  if geglu:
422
450
  modeling_gemma.GemmaMLP = LigerGEGLUMLP
423
451
  if fused_linear_cross_entropy:
@@ -474,6 +502,7 @@ def apply_liger_kernel_to_gemma2(
474
502
  assert not (
475
503
  cross_entropy and fused_linear_cross_entropy
476
504
  ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
505
+
477
506
  from transformers.models.gemma2 import modeling_gemma2
478
507
  from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
479
508
 
@@ -490,7 +519,13 @@ def apply_liger_kernel_to_gemma2(
490
519
  # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
491
520
  modeling_gemma2.Gemma2RMSNorm = LigerRMSNormForGemma2
492
521
  if cross_entropy:
493
- modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
522
+ if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
523
+ from transformers.loss.loss_utils import nn
524
+
525
+ nn.functional.cross_entropy = liger_cross_entropy
526
+ else:
527
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
528
+ modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
494
529
  if fused_linear_cross_entropy:
495
530
  if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
496
531
  modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
@@ -562,8 +597,15 @@ def apply_liger_kernel_to_qwen2(
562
597
  modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
563
598
  if rms_norm:
564
599
  modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
600
+
565
601
  if cross_entropy:
566
- modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
602
+ if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
603
+ from transformers.loss.loss_utils import nn
604
+
605
+ nn.functional.cross_entropy = liger_cross_entropy
606
+ else:
607
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
608
+ modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
567
609
 
568
610
  # import pdb; pdb.set_trace()
569
611
  if fused_linear_cross_entropy:
@@ -710,7 +752,13 @@ def apply_liger_kernel_to_phi3(
710
752
  if swiglu:
711
753
  modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
712
754
  if cross_entropy:
713
- modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
755
+ if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
756
+ from transformers.loss.loss_utils import nn
757
+
758
+ nn.functional.cross_entropy = liger_cross_entropy
759
+ else:
760
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
761
+ modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
714
762
  if fused_linear_cross_entropy:
715
763
  if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
716
764
  modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.4.0.dev20241109010649
3
+ Version: 0.4.0.dev20241112204954
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -426,3 +426,4 @@ Biblatex entry:
426
426
  ↑ Back to Top ↑
427
427
  </a>
428
428
  </p>
429
+
@@ -1,4 +1,4 @@
1
- liger_kernel/env_report.py,sha256=LFUJ6UMkFFGPBYXBlqHFGy4bhsemEpSI-_1edSazlHI,1130
1
+ liger_kernel/env_report.py,sha256=jye8RvUkmhqaIshdeIpoUABoAu7FPKJUib4FnAfvkpw,1132
2
2
  liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  liger_kernel/ops/cross_entropy.py,sha256=sfUb7-jIZp0EKXjg1DYy2Wdzw_Mg-mHmGoR5bpdm4tw,15526
4
4
  liger_kernel/ops/fused_linear_cross_entropy.py,sha256=JPiQ0TgPjtQ-3F5ovC0b5ZnBk067XUmzyNuGO3KZv44,9963
@@ -17,7 +17,7 @@ liger_kernel/ops/experimental/mm_int8int2.py,sha256=JpGVZCgRC6T8XMUJ_QbZRS2XU1bh
17
17
  liger_kernel/transformers/__init__.py,sha256=gia-eBxr7TLxU0GdDf8AfCY4WgDlFLqIGSt7EoQGsBA,1336
18
18
  liger_kernel/transformers/auto_model.py,sha256=RMIwQHSiXoksXFTIqFZ4PLBgoqkxJJAT3q1Qh47bGN8,1552
19
19
  liger_kernel/transformers/cross_entropy.py,sha256=yEm_YQ7oa3_BzT3hdW6KrAslduhSqWcJQVNZZDcWCg4,1758
20
- liger_kernel/transformers/functional.py,sha256=Wgv6vJtLdGa8xJoXEKeJz-QhkICZuq6DNmnXhoAJR04,1235
20
+ liger_kernel/transformers/functional.py,sha256=Hd4WvxNqOJHM9HmRfAQueRnmOy5WU9nFsFygB5Iv8Xs,2000
21
21
  liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=_i0PXSp5iZ9pKXdEeZ4lvHCENJYjV4y74yz3ZRG5XQg,1484
22
22
  liger_kernel/transformers/fused_linear_jsd.py,sha256=MJ-KjmLZnakuoVpnbDGkd95DQgvESniyrRWYzollVZM,4066
23
23
  liger_kernel/transformers/geglu.py,sha256=QcrME_8ooIn0xa59LaC0aoOdRrBIFd11Y0bAyF0NfCw,1130
@@ -25,7 +25,7 @@ liger_kernel/transformers/group_norm.py,sha256=FJ9R7mS9G1wO-GRIQ6QKSmIhnZ6nQ6GIk
25
25
  liger_kernel/transformers/jsd.py,sha256=W-5CypO2mx4-bUWOxq1KScfCdoXlLoYbtt5xBnRzMs4,3056
26
26
  liger_kernel/transformers/kl_div.py,sha256=qVhjBg6tjRyue5iZ3NFxo8uySY4JuIFJyv0IM_50F24,431
27
27
  liger_kernel/transformers/layer_norm.py,sha256=fd6o4kSHJWolQMWxh-l1qObfgL08ruNbUoBiANKX1ow,972
28
- liger_kernel/transformers/monkey_patch.py,sha256=Qh1AuF9Rt2zHzlcyNYaRIVIvN2nZLEpTrYV3K7ECDbk,36171
28
+ liger_kernel/transformers/monkey_patch.py,sha256=RSnHwStyOi5C4xBbHrKvo32X2hUU7InjSyBzb4iu9T0,38184
29
29
  liger_kernel/transformers/rms_norm.py,sha256=4XfMQI6dORF7s_5qUqVHKWv-3IUomaimU2dg-NwnpoM,1035
30
30
  liger_kernel/transformers/rope.py,sha256=m-ah8vZBYW8tfplTXCiAPMHJWlB1tdp_JPXJeWE-Boo,943
31
31
  liger_kernel/transformers/swiglu.py,sha256=0-tVJ8xEYfhxnduc16PflXFj8sZPxdx9sHUn3hfwCI4,2468
@@ -43,9 +43,9 @@ liger_kernel/transformers/model/qwen2.py,sha256=EyhSSzQOskGjSnCsKMZpd1s5IAIlHd5P
43
43
  liger_kernel/transformers/model/qwen2_vl.py,sha256=j6xAhp9AG195dsZK5f8dFYVM9uKtWApZrggT5Y08jn4,7055
44
44
  liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
45
45
  liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
46
- liger_kernel_nightly-0.4.0.dev20241109010649.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
47
- liger_kernel_nightly-0.4.0.dev20241109010649.dist-info/METADATA,sha256=Ym9n-6VBrfKW1NClTcrCH1YbsWtpz-YW4WtaVhf4TxM,28095
48
- liger_kernel_nightly-0.4.0.dev20241109010649.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
49
- liger_kernel_nightly-0.4.0.dev20241109010649.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
50
- liger_kernel_nightly-0.4.0.dev20241109010649.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
51
- liger_kernel_nightly-0.4.0.dev20241109010649.dist-info/RECORD,,
46
+ liger_kernel_nightly-0.4.0.dev20241112204954.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
47
+ liger_kernel_nightly-0.4.0.dev20241112204954.dist-info/METADATA,sha256=oxSZh5RfYz52ri89yD3KJGjE1lX_Y1d_Dvi-j5aq12Y,28096
48
+ liger_kernel_nightly-0.4.0.dev20241112204954.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
49
+ liger_kernel_nightly-0.4.0.dev20241112204954.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
50
+ liger_kernel_nightly-0.4.0.dev20241112204954.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
51
+ liger_kernel_nightly-0.4.0.dev20241112204954.dist-info/RECORD,,