liger-kernel-nightly 0.5.10.dev20250702150221__py3-none-any.whl → 0.5.10.dev20250704061125__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
liger_kernel/ops/geglu.py CHANGED
@@ -40,7 +40,7 @@ def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE
40
40
  tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
41
41
  tanh_result = tanh(tanh_arg)
42
42
  geglu_a = 0.5 * a_row * (1 + tanh_result)
43
- c_row = geglu_a * b_row
43
+ c_row = geglu_a.cast(b_row.dtype) * b_row
44
44
  tl.store(c + col_offsets, c_row, mask=mask)
45
45
 
46
46
 
@@ -26,7 +26,7 @@ def _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BL
26
26
  # sigmoid requires type float32
27
27
  a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
28
28
  b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
29
- c_row = silu(a_row) * b_row
29
+ c_row = silu(a_row).cast(b_row.dtype) * b_row
30
30
  tl.store(c_ptr + col_offsets, c_row, mask=mask)
31
31
 
32
32
 
@@ -27,6 +27,7 @@ def lce_forward_deprecated(
27
27
  output_hidden_states: Optional[bool] = None,
28
28
  return_dict: Optional[bool] = None,
29
29
  cache_position: Optional[torch.LongTensor] = None,
30
+ skip_logits: Optional[bool] = None,
30
31
  ) -> Union[Tuple, CausalLMOutputWithPast]:
31
32
  r"""
32
33
 
@@ -81,7 +82,14 @@ def lce_forward_deprecated(
81
82
  loss = None
82
83
  logits = None
83
84
 
84
- if self.training and (labels is not None):
85
+ if skip_logits and labels is None:
86
+ raise ValueError("skip_logits is True, but labels is None")
87
+
88
+ if skip_logits is None:
89
+ # By default, if in training mode, don't materialize logits
90
+ skip_logits = self.training and labels is not None
91
+
92
+ if skip_logits:
85
93
  shift_hidden_states = hidden_states[..., :-1, :].contiguous()
86
94
  shift_labels = labels[..., 1:].contiguous()
87
95
 
@@ -30,6 +30,7 @@ def lce_forward_deprecated(
30
30
  output_hidden_states: Optional[bool] = None,
31
31
  return_dict: Optional[bool] = None,
32
32
  cache_position: Optional[torch.LongTensor] = None,
33
+ skip_logits: Optional[bool] = None,
33
34
  **kwargs,
34
35
  ) -> Union[Tuple, CausalLMOutputWithPast]:
35
36
  r"""
@@ -85,7 +86,14 @@ def lce_forward_deprecated(
85
86
  loss = None
86
87
  logits = None
87
88
 
88
- if self.training and (labels is not None):
89
+ if skip_logits and labels is None:
90
+ raise ValueError("skip_logits is True, but labels is None")
91
+
92
+ if skip_logits is None:
93
+ # By default, if in training mode, don't materialize logits
94
+ skip_logits = self.training and labels is not None
95
+
96
+ if skip_logits:
89
97
  shift_hidden_states = hidden_states[..., :-1, :].contiguous()
90
98
  shift_labels = labels[..., 1:].contiguous()
91
99
 
@@ -37,6 +37,7 @@ def lce_forward_deprecated(
37
37
  output_hidden_states: Optional[bool] = None,
38
38
  return_dict: Optional[bool] = None,
39
39
  cache_position: Optional[torch.LongTensor] = None,
40
+ skip_logits: Optional[bool] = None,
40
41
  ) -> Union[Tuple, CausalLMOutputWithPast]:
41
42
  r"""
42
43
  Copy paste llama forward but replace torch cross entropy with liger fused linear cross entropy
@@ -91,7 +92,15 @@ def lce_forward_deprecated(
91
92
  loss = None
92
93
  logits = None
93
94
 
94
- if self.training and (labels is not None):
95
+ # if in training mode, don't materialize logits
96
+ if skip_logits and labels is None:
97
+ raise ValueError("skip_logits is True, but labels is None")
98
+
99
+ if skip_logits is None:
100
+ # By default, if in training mode, don't materialize logits
101
+ skip_logits = self.training and labels is not None
102
+
103
+ if skip_logits:
95
104
  shift_hidden_states = hidden_states[..., :-1, :].contiguous()
96
105
  shift_labels = labels[..., 1:].contiguous()
97
106
 
@@ -133,6 +133,3 @@ def lce_forward(
133
133
  hidden_states=outputs.hidden_states,
134
134
  attentions=outputs.attentions,
135
135
  )
136
-
137
-
138
- # Note: Grad Acc is not fixed in mistral at transformer 4.46.1
@@ -26,6 +26,7 @@ def lce_forward_deprecated(
26
26
  output_hidden_states: Optional[bool] = None,
27
27
  return_dict: Optional[bool] = None,
28
28
  cache_position: Optional[torch.LongTensor] = None,
29
+ skip_logits: Optional[bool] = None,
29
30
  ) -> Union[Tuple, CausalLMOutputWithPast]:
30
31
  r"""
31
32
  Copy paste phi3 forward from transfomers v4.44.2 but replace torch cross entropy with liger fused linear cross entropy
@@ -80,7 +81,14 @@ def lce_forward_deprecated(
80
81
  loss = None
81
82
  logits = None
82
83
 
83
- if self.training and labels is not None:
84
+ if skip_logits and labels is None:
85
+ raise ValueError("skip_logits is True, but labels is None")
86
+
87
+ if skip_logits is None:
88
+ # By default, if in training mode, don't materialize logits
89
+ skip_logits = self.training and labels is not None
90
+
91
+ if skip_logits:
84
92
  shift_hidden_states = hidden_states[..., :-1, :].contiguous()
85
93
  shift_labels = labels[..., 1:].contiguous()
86
94
 
@@ -26,6 +26,7 @@ def lce_forward_deprecated(
26
26
  output_hidden_states: Optional[bool] = None,
27
27
  return_dict: Optional[bool] = None,
28
28
  cache_position: Optional[torch.LongTensor] = None,
29
+ skip_logits: Optional[bool] = None,
29
30
  ) -> Union[Tuple, CausalLMOutputWithPast]:
30
31
  r"""
31
32
  Copy paste Qwen2's forward but replace torch cross entropy with liger fused linear cross entropy
@@ -80,6 +81,13 @@ def lce_forward_deprecated(
80
81
  loss = None
81
82
  logits = None
82
83
 
84
+ if skip_logits and labels is None:
85
+ raise ValueError("skip_logits is True, but labels is None")
86
+
87
+ if skip_logits is None:
88
+ # By default, if in training mode, don't materialize logits
89
+ skip_logits = self.training and labels is not None
90
+
83
91
  if self.training and (labels is not None):
84
92
  shift_hidden_states = hidden_states[..., :-1, :].contiguous()
85
93
  shift_labels = labels[..., 1:].contiguous()
@@ -611,10 +611,17 @@ def apply_liger_kernel_to_mistral(
611
611
  if cross_entropy:
612
612
  modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
613
613
  if fused_linear_cross_entropy:
614
- if model is not None:
615
- model.forward = MethodType(mistral_lce_forward, model)
614
+ if transformer_version >= version.parse("4.49.0"):
615
+ if model is not None:
616
+ model.forward = MethodType(mistral_lce_forward, model)
617
+ else:
618
+ modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
616
619
  else:
617
- modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
620
+ logger.warning(
621
+ "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
622
+ )
623
+ logger.warning("LigerFusedLinearCrossEntropy patch is not applied.")
624
+
618
625
  if swiglu:
619
626
  modeling_mistral.MistralMLP = LigerSwiGLUMLP
620
627
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.10.dev20250702150221
3
+ Version: 0.5.10.dev20250704061125
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -22,7 +22,7 @@ liger_kernel/ops/dyt.py,sha256=gCLz4S8aul8SY9nvIGaoK67aGb7U9MJRQdo3ONqmQYs,5417
22
22
  liger_kernel/ops/fused_linear_cross_entropy.py,sha256=5fbGhN85n3zf0uIdJ7PYHWIRzTf0VTFiS0ARtOmqIP0,11020
23
23
  liger_kernel/ops/fused_linear_jsd.py,sha256=CSoprxb-YcJy-YUKiTcYkxN8sb9h2kdk_iHuncvSV5c,9683
24
24
  liger_kernel/ops/fused_neighborhood_attention.py,sha256=vPi5xbnh6wxyZehaqo6Tuilqo2fN5SGDiONjnNmIKqs,35556
25
- liger_kernel/ops/geglu.py,sha256=axGvCIvlBzuluoAIrWTsp2iZM4BFKNInkPov8YVvH9E,4126
25
+ liger_kernel/ops/geglu.py,sha256=r0WSq9E93zzynL44Wh8femzOWK07_SseBM_pJUyxT3s,4144
26
26
  liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2woggk,10837
27
27
  liger_kernel/ops/grpo_loss.py,sha256=anRnv7k1-AV3pCC6_TqP0GMg78YYUfRAJrbpx6PVhl0,9448
28
28
  liger_kernel/ops/jsd.py,sha256=onHp5T3MbvJaVz5Vup7Ww6EQp_HTaZeayTjJk6FgQMY,7042
@@ -34,7 +34,7 @@ liger_kernel/ops/rms_norm.py,sha256=-rcgHwWCxlA-Syec2XhdW4jfOeCDt2r7qwjslgXFYDU,
34
34
  liger_kernel/ops/rope.py,sha256=ofmBOkUpZZO-Q8Z5B_LOFYYLD-YT-8WnJ4vGOrDYouI,8943
35
35
  liger_kernel/ops/softmax.py,sha256=tgORx6MK1IDDtZKqGarj0IPIVjqAIEUXXYPiinhRdtI,5864
36
36
  liger_kernel/ops/sparsemax.py,sha256=AeWe1xgkHJFEKWTj2vu_0hj7LztGvjqXAps-QTpCY0U,5087
37
- liger_kernel/ops/swiglu.py,sha256=KmgMjaJQnbLLgZn2nEpbwHU_xpnYRweCyrLQSVvM1vA,3015
37
+ liger_kernel/ops/swiglu.py,sha256=D7nd4u_LInwsIRNCDdY77lqnTz8-W5dJrpEAt8zEO_A,3033
38
38
  liger_kernel/ops/tvd.py,sha256=FHJtLQI95ijqgg9UtaHpMAjSCiPxB6CduPwPMcGxelc,6405
39
39
  liger_kernel/ops/utils.py,sha256=uoFKQqo-34N2TWQNvXMFywqGiOMMXNEVBxVojzlUAa0,3836
40
40
  liger_kernel/ops/experimental/embedding.py,sha256=tolj3tItkzpSb30zWqDN2_yX4ectflaQ8HMyKyFIQc8,4172
@@ -54,7 +54,7 @@ liger_kernel/transformers/grpo_loss.py,sha256=uAkUNKSnUGEOqa82L9w2e6AI1kcmG8K45-
54
54
  liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCcScY,2979
55
55
  liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
56
56
  liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
57
- liger_kernel/transformers/monkey_patch.py,sha256=YkX0LT6lISg3UTqFjjt9kTr36WgiHvYTQObAS1_Bmi4,85172
57
+ liger_kernel/transformers/monkey_patch.py,sha256=rXmaVry8hdpnH8HunfJhZmrsdlwAxjMP3x10ZYMnTy4,85554
58
58
  liger_kernel/transformers/multi_token_attention.py,sha256=l9VDICK0dfmifUDW668hGscP8AHq2rYcM2oGUa3baRQ,1751
59
59
  liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
60
60
  liger_kernel/transformers/rms_norm.py,sha256=vkekcvTeWY8vL4H6hg3t0XeY0Ew_3OFMPHuzqlxPPVw,2719
@@ -66,21 +66,21 @@ liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx
66
66
  liger_kernel/transformers/tvd.py,sha256=XrRfyJIqN6HFxXk8MYyFVZM1OLz3mtSbRZvWfZ_JerQ,450
67
67
  liger_kernel/transformers/experimental/embedding.py,sha256=2P0QYdlFyFrG5OqTzTa1wcRgDSyjBMv5i1a7BrDPDQw,881
68
68
  liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
69
- liger_kernel/transformers/model/gemma.py,sha256=gvP-9zZ1e-DQD06qltWmRhiJClJDtkMQL1UrPMMZZGQ,9730
70
- liger_kernel/transformers/model/gemma2.py,sha256=ORmzklEAMpk93nToRo4d_ZJbM4ScVE2szczsEL4hw7w,11019
69
+ liger_kernel/transformers/model/gemma.py,sha256=mNX-mIwV6jI4zfbrUHp0C468pOmjzsL7mjXipGt-eS0,10007
70
+ liger_kernel/transformers/model/gemma2.py,sha256=R_JFPyWTk7RyA7D05ZiIaNO5pX8gWcvfWf-6rdCRMxs,11296
71
71
  liger_kernel/transformers/model/gemma3.py,sha256=JI4jj9K660HeRsofB6cpkCHBQ0OsazElArRtKUehUmw,15945
72
72
  liger_kernel/transformers/model/glm4.py,sha256=GlnEhdGJuDIqp2R9qC54biY3HwV1tWmfpJm6ijoAsrM,5257
73
- liger_kernel/transformers/model/llama.py,sha256=LcIxVfF0PXXWHBVJa6Ody_5fAtIpxQcI4jC_j-o51fU,12503
73
+ liger_kernel/transformers/model/llama.py,sha256=i8jJgyZsMKWQ-zKloETLugtwFpUOdaWxLDceciFXKd4,12832
74
74
  liger_kernel/transformers/model/llama4.py,sha256=IgbB8sTh3dlETQnaNNy1bZLuXy-Nt7qmeAjF27ydGpg,4210
75
75
  liger_kernel/transformers/model/llava.py,sha256=bLCioday_SOm69ogMDBhy_4UsVkH2-BSl93-EXY6-7I,15076
76
76
  liger_kernel/transformers/model/loss_utils.py,sha256=WWAMdiONPaXpIvxyOim_0igLrYh0yyOok5Q9_L9xvZw,1787
77
- liger_kernel/transformers/model/mistral.py,sha256=okKkyashfFLfhjIT--f3JY6JHOslOtDI8U1dlpBC2Zs,5565
77
+ liger_kernel/transformers/model/mistral.py,sha256=syYNL8dLThX2-4uC13Lu0krEZ5zw3InviDUR3AJmc-I,5500
78
78
  liger_kernel/transformers/model/mixtral.py,sha256=VY-y73IyjcCyWyI7ahxXLw0fJrhgjYfr1xwRYtsHX0o,11396
79
79
  liger_kernel/transformers/model/mllama.py,sha256=my29NXk-p6ckQaP8qDIN8e318yI_9mQZHt38MV3SqLY,11280
80
80
  liger_kernel/transformers/model/olmo2.py,sha256=6L_bo-ZUgO1lYppdJneOtYxNIylQKS6BiGp13g7Uq9E,5259
81
81
  liger_kernel/transformers/model/paligemma.py,sha256=xuIx3oOwTgftU3jqLfWOxUxgCLBNJh0yNC21an9qDjo,18773
82
- liger_kernel/transformers/model/phi3.py,sha256=m-MD_OuTaYMGZhHOvl-RHOVEObrL8tL5cBv3VTNd4F0,10376
83
- liger_kernel/transformers/model/qwen2.py,sha256=SdN7V-MI3eX9s2DAFRvC1g-G146uG_5n1fnNdY9QwYk,9658
82
+ liger_kernel/transformers/model/phi3.py,sha256=zAzBVNOA16B16yy2HWsEgOMHhLoYkpWOWPgBT4z95WI,10655
83
+ liger_kernel/transformers/model/qwen2.py,sha256=3fpOTEOkniQmkCfN1KUa3KhseHJVzhj2Ht9FdYPUy-E,9962
84
84
  liger_kernel/transformers/model/qwen2_5_vl.py,sha256=zEVVwotCXnAm3RRc8-1Nc8uitSWrwW4B9dYY2uOZDwg,6331
85
85
  liger_kernel/transformers/model/qwen2_vl.py,sha256=5vK-vtCDpKZ2w33xYp2BS8kQYWUbKMqaiKvQcI27Mss,5884
86
86
  liger_kernel/transformers/model/qwen3.py,sha256=w2jBHuK9kK9EmOr5dnEIXNQXUgUSV_sJUkXSEwxLPHs,4885
@@ -89,9 +89,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
89
89
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
90
90
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
91
91
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
92
- liger_kernel_nightly-0.5.10.dev20250702150221.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
93
- liger_kernel_nightly-0.5.10.dev20250702150221.dist-info/METADATA,sha256=CoPcolC_DjZu7v28Cqy2kQoE65U6f5Rx1EKf55y9NxU,24536
94
- liger_kernel_nightly-0.5.10.dev20250702150221.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
95
- liger_kernel_nightly-0.5.10.dev20250702150221.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
96
- liger_kernel_nightly-0.5.10.dev20250702150221.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
97
- liger_kernel_nightly-0.5.10.dev20250702150221.dist-info/RECORD,,
92
+ liger_kernel_nightly-0.5.10.dev20250704061125.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
93
+ liger_kernel_nightly-0.5.10.dev20250704061125.dist-info/METADATA,sha256=7mx4Zgy5kdvnanl50nrzJ9HE6vTou5oeeOLx45V_T1c,24536
94
+ liger_kernel_nightly-0.5.10.dev20250704061125.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
95
+ liger_kernel_nightly-0.5.10.dev20250704061125.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
96
+ liger_kernel_nightly-0.5.10.dev20250704061125.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
97
+ liger_kernel_nightly-0.5.10.dev20250704061125.dist-info/RECORD,,