liger-kernel-nightly 0.5.6.dev20250411210855__py3-none-any.whl → 0.5.6.dev20250412004725__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -222,7 +222,7 @@ def lce_forward(
222
222
  lm_head_weight=self.lm_head.weight,
223
223
  labels=labels,
224
224
  hidden_size=self.config.hidden_size,
225
- softcap=self.config.final_logit_softcapping,
225
+ final_logit_softcapping=self.config.final_logit_softcapping,
226
226
  **loss_kwargs,
227
227
  )
228
228
 
@@ -112,7 +112,7 @@ def causal_forward(
112
112
  lm_head_weight=self.lm_head.weight,
113
113
  labels=labels,
114
114
  hidden_size=self.config.hidden_size,
115
- softcap=self.config.final_logit_softcapping,
115
+ final_logit_softcapping=self.config.final_logit_softcapping,
116
116
  **loss_kwargs,
117
117
  )
118
118
 
@@ -1,14 +1,18 @@
1
+ from typing import Optional
2
+
3
+ import torch
1
4
  import torch.nn as nn
2
5
 
3
6
  import liger_kernel.transformers.functional as F
4
7
 
5
8
 
6
9
  def fixed_fused_linear_cross_entropy(
7
- hidden_states,
8
- lm_head_weight,
9
- target,
10
- num_items_in_batch: int = None,
10
+ hidden_states: torch.Tensor,
11
+ lm_head_weight: torch.Tensor,
12
+ target: torch.Tensor,
13
+ num_items_in_batch: Optional[int] = None,
11
14
  ignore_index: int = -100,
15
+ final_logit_softcapping: Optional[float] = None,
12
16
  **kwargs,
13
17
  ):
14
18
  reduction = "sum" if num_items_in_batch is not None else "mean"
@@ -18,7 +22,7 @@ def fixed_fused_linear_cross_entropy(
18
22
  target,
19
23
  reduction=reduction,
20
24
  ignore_index=ignore_index,
21
- **kwargs,
25
+ softcap=final_logit_softcapping,
22
26
  )
23
27
  if reduction == "sum":
24
28
  loss = loss / num_items_in_batch
@@ -31,15 +35,17 @@ def LigerForCausalLMLoss(
31
35
  lm_head_weight,
32
36
  labels,
33
37
  hidden_size: int,
34
- num_items_in_batch: int = None,
38
+ num_items_in_batch: Optional[int] = None,
35
39
  ignore_index: int = -100,
40
+ shift_labels: Optional[torch.Tensor] = None,
41
+ final_logit_softcapping: Optional[float] = None,
36
42
  **kwargs,
37
43
  ):
38
44
  # Skip upcast since intermediate values for the loss are all fp32 in kernel
39
- labels = labels.to(hidden_states.device)
40
- # Shift so that token < n predict n
41
- labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
42
- shift_labels = labels[..., 1:].contiguous()
45
+ if shift_labels is None:
46
+ # Shift so that token < n predict n
47
+ labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
48
+ shift_labels = labels[..., 1:].contiguous()
43
49
 
44
50
  # Flatten the tokens
45
51
  hidden_states = hidden_states.view(-1, hidden_size)
@@ -52,6 +58,7 @@ def LigerForCausalLMLoss(
52
58
  shift_labels,
53
59
  num_items_in_batch,
54
60
  ignore_index,
61
+ final_logit_softcapping,
55
62
  **kwargs,
56
63
  )
57
64
  return loss
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.6.dev20250411210855
3
+ Version: 0.5.6.dev20250412004725
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -56,11 +56,11 @@ liger_kernel/transformers/tvd.py,sha256=XrRfyJIqN6HFxXk8MYyFVZM1OLz3mtSbRZvWfZ_J
56
56
  liger_kernel/transformers/experimental/embedding.py,sha256=2P0QYdlFyFrG5OqTzTa1wcRgDSyjBMv5i1a7BrDPDQw,881
57
57
  liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
58
58
  liger_kernel/transformers/model/gemma.py,sha256=-JoHKWjtYPpxHQa6QbCwnzX_cctRZG2ZTsaUv-dmOt4,9816
59
- liger_kernel/transformers/model/gemma2.py,sha256=tLl1v-O8K0NZ7BQcSf1dE3450-xV72RAk4E5oTPcu_s,10907
60
- liger_kernel/transformers/model/gemma3.py,sha256=PjAfFtupT9EW0sb57Hx8UJXcnvq9HFgNndeAE4EqyPw,16086
59
+ liger_kernel/transformers/model/gemma2.py,sha256=n4MZupFGDMvtnvkvkNhRrxXS3ZF341BVfyLjrOXp10g,10923
60
+ liger_kernel/transformers/model/gemma3.py,sha256=ge3JYchiKvX1G1Zp00jX2zmQK2K7ymJoZAxbb2ggslw,16102
61
61
  liger_kernel/transformers/model/llama.py,sha256=UVXQLRW7rCU5vPab54dLNS3ER37eM446peHX00Yz6eA,10493
62
62
  liger_kernel/transformers/model/llava.py,sha256=b0pEagjUbu2-eS9xegjyfl1DwIXLwZcNpff55ibaMbA,17601
63
- liger_kernel/transformers/model/loss_utils.py,sha256=Z-fUrf-cUDUjUIH7Tl9OL2hT8nmtx7ES3kg8syuWKy4,1476
63
+ liger_kernel/transformers/model/loss_utils.py,sha256=WWAMdiONPaXpIvxyOim_0igLrYh0yyOok5Q9_L9xvZw,1787
64
64
  liger_kernel/transformers/model/mistral.py,sha256=RacuKcckuDK6oSraCGD0R0bm-fE0K3q-lkYaAC56C2E,5481
65
65
  liger_kernel/transformers/model/mixtral.py,sha256=gLcqGabdv1XnuciS9b-TpkTDnGL8K32Hoq9j2vZMBRY,11502
66
66
  liger_kernel/transformers/model/mllama.py,sha256=75mxtmMsNd_q8KlKeawj2uMP6v2KjDuUi4nsUKM5jqA,11308
@@ -74,9 +74,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
74
74
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
75
75
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
76
76
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
77
- liger_kernel_nightly-0.5.6.dev20250411210855.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
78
- liger_kernel_nightly-0.5.6.dev20250411210855.dist-info/METADATA,sha256=mX6Na52mRBO2g2I7Qqj34QGM17tMQAZLNjE7XX0g9fA,23297
79
- liger_kernel_nightly-0.5.6.dev20250411210855.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
80
- liger_kernel_nightly-0.5.6.dev20250411210855.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
81
- liger_kernel_nightly-0.5.6.dev20250411210855.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
82
- liger_kernel_nightly-0.5.6.dev20250411210855.dist-info/RECORD,,
77
+ liger_kernel_nightly-0.5.6.dev20250412004725.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
78
+ liger_kernel_nightly-0.5.6.dev20250412004725.dist-info/METADATA,sha256=NlVGUs75aShCOyLfsS9C_shIfJcUe7B_JIqZIvr1K3I,23297
79
+ liger_kernel_nightly-0.5.6.dev20250412004725.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
80
+ liger_kernel_nightly-0.5.6.dev20250412004725.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
81
+ liger_kernel_nightly-0.5.6.dev20250412004725.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
82
+ liger_kernel_nightly-0.5.6.dev20250412004725.dist-info/RECORD,,