liger-kernel-nightly 0.6.2.dev20250823034010__py3-none-any.whl → 0.6.2.dev20250830153327__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -101,8 +101,21 @@ def fused_linear_cross_entropy_forward(
101
101
  # Compute softmax to get predicted probabilities
102
102
  probs = torch.softmax(logits_for_softmax, dim=-1)
103
103
 
104
- # Get the predicted probability for each target token
105
- pred_probs = torch.gather(probs, -1, target_chunk.unsqueeze(-1)).squeeze(-1)
104
+ # Get predicted probabilities for token scaling, handling ignored targets
105
+ valid_target_mask = target_chunk != ignore_index
106
+ valid_targets = target_chunk[valid_target_mask]
107
+
108
+ if len(valid_targets) > 0:
109
+ # Gather probabilities only for valid targets
110
+ valid_probs = probs[valid_target_mask]
111
+ pred_probs_valid = torch.gather(valid_probs, -1, valid_targets.unsqueeze(-1)).squeeze(-1)
112
+
113
+ # Create full tensor with zeros for ignored targets
114
+ pred_probs = torch.zeros_like(target_chunk, dtype=probs.dtype, device=probs.device)
115
+ pred_probs[valid_target_mask] = pred_probs_valid
116
+ else:
117
+ # All targets are ignored
118
+ pred_probs = torch.zeros_like(target_chunk, dtype=probs.dtype, device=probs.device)
106
119
 
107
120
  # Store the scaling factors
108
121
  scaling_factors = pred_probs.detach() # Detach to ensure no gradient flow
@@ -25,6 +25,7 @@ def fixed_fused_linear_cross_entropy(
25
25
  ignore_index=ignore_index,
26
26
  softcap=final_logit_softcapping,
27
27
  accum_dtype=accum_dtype,
28
+ **kwargs,
28
29
  )
29
30
  if reduction == "sum":
30
31
  loss = loss / num_items_in_batch
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.6.2.dev20250823034010
3
+ Version: 0.6.2.dev20250830153327
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -20,7 +20,7 @@ liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
20
20
  liger_kernel/ops/cross_entropy.py,sha256=e8THGnhOcy_0SbOLABx67HEM7-B8a8pG7nDKbCRpQKM,19123
21
21
  liger_kernel/ops/dyt.py,sha256=gCLz4S8aul8SY9nvIGaoK67aGb7U9MJRQdo3ONqmQYs,5417
22
22
  liger_kernel/ops/fused_add_rms_norm.py,sha256=UBqmlqFCmhSAIpkNKd8rrfXatX7Z4J9bp2dX9A0lrJQ,14017
23
- liger_kernel/ops/fused_linear_cross_entropy.py,sha256=AIlKMOnM3J7ZeAgPP1uvA3T4OIeRkz6TTr_Lg9XgZGY,13581
23
+ liger_kernel/ops/fused_linear_cross_entropy.py,sha256=6rB3pdwU97Ivl2IHndPJjzhP28E9Fd0pUQcPHLiuCjc,14290
24
24
  liger_kernel/ops/fused_linear_jsd.py,sha256=CSoprxb-YcJy-YUKiTcYkxN8sb9h2kdk_iHuncvSV5c,9683
25
25
  liger_kernel/ops/fused_neighborhood_attention.py,sha256=vPi5xbnh6wxyZehaqo6Tuilqo2fN5SGDiONjnNmIKqs,35556
26
26
  liger_kernel/ops/geglu.py,sha256=r0WSq9E93zzynL44Wh8femzOWK07_SseBM_pJUyxT3s,4144
@@ -79,7 +79,7 @@ liger_kernel/transformers/model/glm4v.py,sha256=zbV3agptEYpGAD0eeCRwIpJAhJUviTT5
79
79
  liger_kernel/transformers/model/llama.py,sha256=i8jJgyZsMKWQ-zKloETLugtwFpUOdaWxLDceciFXKd4,12832
80
80
  liger_kernel/transformers/model/llama4.py,sha256=IgbB8sTh3dlETQnaNNy1bZLuXy-Nt7qmeAjF27ydGpg,4210
81
81
  liger_kernel/transformers/model/llava.py,sha256=bLCioday_SOm69ogMDBhy_4UsVkH2-BSl93-EXY6-7I,15076
82
- liger_kernel/transformers/model/loss_utils.py,sha256=YiYsmRHIuoRnFjGpwyIM18DCsrPPmO32YWMWqkEm1UQ,1867
82
+ liger_kernel/transformers/model/loss_utils.py,sha256=02RVkPI7Qs4ZP4yU_udCAvD_2hgIaHmxremRKe3N7EE,1885
83
83
  liger_kernel/transformers/model/mistral.py,sha256=syYNL8dLThX2-4uC13Lu0krEZ5zw3InviDUR3AJmc-I,5500
84
84
  liger_kernel/transformers/model/mixtral.py,sha256=VY-y73IyjcCyWyI7ahxXLw0fJrhgjYfr1xwRYtsHX0o,11396
85
85
  liger_kernel/transformers/model/mllama.py,sha256=NhJtlXiuszJHo5YSJOvSGYH47ly7Hse8r-5BKznBg9s,11522
@@ -96,9 +96,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
96
96
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
97
97
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
98
98
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
99
- liger_kernel_nightly-0.6.2.dev20250823034010.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
100
- liger_kernel_nightly-0.6.2.dev20250823034010.dist-info/METADATA,sha256=j0Am_cC3_HVR7UGo0hYJ19e7U6oO7VXKFm7ykn5atNU,24504
101
- liger_kernel_nightly-0.6.2.dev20250823034010.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
102
- liger_kernel_nightly-0.6.2.dev20250823034010.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
103
- liger_kernel_nightly-0.6.2.dev20250823034010.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
104
- liger_kernel_nightly-0.6.2.dev20250823034010.dist-info/RECORD,,
99
+ liger_kernel_nightly-0.6.2.dev20250830153327.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
100
+ liger_kernel_nightly-0.6.2.dev20250830153327.dist-info/METADATA,sha256=fla7v4BScWkzdeXCOwxI_b7g7kJti3IyQA9BwmlN8GM,24504
101
+ liger_kernel_nightly-0.6.2.dev20250830153327.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
102
+ liger_kernel_nightly-0.6.2.dev20250830153327.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
103
+ liger_kernel_nightly-0.6.2.dev20250830153327.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
104
+ liger_kernel_nightly-0.6.2.dev20250830153327.dist-info/RECORD,,