liger-kernel-nightly 0.5.5.dev20250324181221__py3-none-any.whl → 0.5.5.dev20250327235249__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

liger_kernel/chunked_loss/fused_linear_distillation.py
@@ -112,6 +112,21 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             compute_ce_loss=compute_ce_loss,
         )

+        # If the teacher and student token size is different, pad student logits to match the teacher's.
+        # This only applies to cases where they share exactly the same vocab and tokenizer just
+        # that teacher logit is padded for some training efficiency such as
+        # https://huggingface.co/Qwen/Qwen1.5-72B-Chat/discussions/1#662883f568adf59b07b176d2
+        teacher_vocab_size = teacher_weight.shape[0]
+        student_vocab_size = student_weight.shape[0]
+        if teacher_vocab_size > student_vocab_size:
+            pad_size = teacher_vocab_size - student_vocab_size
+            pad_tensor = torch.zeros(
+                (*student_logits_chunk.shape[:-1], pad_size),
+                dtype=student_logits_chunk.dtype,
+                device=student_logits_chunk.device
+            )
+            student_logits_chunk = torch.cat([student_logits_chunk, pad_tensor], dim=-1)
+
         student_logits_chunk /= temperature
         teacher_logits_chunk /= temperature

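The added block zero-pads the student logits along the vocabulary dimension so they line up with a teacher checkpoint whose vocab was padded out for training efficiency. A minimal, self-contained sketch of the same shape manipulation (the tensor names and sizes below are illustrative only, not the library's internals):

    import torch

    # Illustrative shapes: a student with a 32,000-token vocab distilled against a
    # teacher checkpoint whose output matrix was padded to 32,128 rows.
    student_logits = torch.randn(2, 8, 32000)   # (batch, seq_len, student_vocab)
    teacher_vocab_size = 32128

    pad_size = teacher_vocab_size - student_logits.shape[-1]
    if pad_size > 0:
        pad_tensor = torch.zeros(
            (*student_logits.shape[:-1], pad_size),
            dtype=student_logits.dtype,
            device=student_logits.device,
        )
        # Append zero logits for the padded (never-used) vocab entries.
        student_logits = torch.cat([student_logits, pad_tensor], dim=-1)

    assert student_logits.shape[-1] == teacher_vocab_size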

liger_kernel/transformers/monkey_patch.py
@@ -52,6 +52,7 @@ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", i
     module.in_place = in_place
     _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
     _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
+    module.__class__.__name__ = LigerRMSNorm.__name__


 def _patch_layer_norm_module(module, eps=1e-6):
@@ -59,6 +60,17 @@ def _patch_layer_norm_module(module, eps=1e-6):
     module.hidden_size = module.normalized_shape
     _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
     _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
+    module.__class__.__name__ = LigerLayerNorm.__name__
+
+
+def _patch_swiglu_module(module, liger_module):
+    _bind_method_to_module(module, "forward", liger_module.forward)
+    module.__class__.__name__ = liger_module.__name__
+
+
+def _patch_geglu_module(module):
+    _bind_method_to_module(module, "forward", LigerGEGLUMLP.forward)
+    module.__class__.__name__ = LigerGEGLUMLP.__name__


 def apply_liger_kernel_to_granite(
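The new _patch_swiglu_module and _patch_geglu_module helpers follow the same pattern as the norm patches above: bind the Liger forward onto the module instance and rename its class so the swap is visible when the model is printed. A rough, self-contained sketch of that pattern (patch_forward and ApproxGELU are hypothetical names for illustration, not part of liger_kernel):

    from types import MethodType

    import torch
    import torch.nn as nn

    def patch_forward(module: nn.Module, replacement: type) -> None:
        # Instance-level override: nn.Module.__call__ now dispatches to the bound method.
        module.forward = MethodType(replacement.forward, module)
        # Rename the class so print(model)/repr() reflects the patched implementation.
        module.__class__.__name__ = replacement.__name__

    class ApproxGELU(nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * torch.sigmoid(1.702 * x)  # sigmoid approximation of GELU

    mlp = nn.Identity()            # stand-in for a decoder layer's MLP block
    patch_forward(mlp, ApproxGELU)
    print(mlp)                     # now prints "ApproxGELU()"
    print(mlp(torch.randn(4)))     # forward runs the replacement implementation

Note that __class__ is shared, so the rename affects every instance of the original module class; this is also how the helpers above make patched submodules show up under their Liger names in a model's repr.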
@@ -134,7 +146,7 @@ def apply_liger_kernel_to_granite(

         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -206,7 +218,7 @@ def apply_liger_kernel_to_llama(

         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -296,7 +308,7 @@ def apply_liger_kernel_to_mllama(
         _patch_rms_norm_module(text_model.norm)
         for decoder_layer in text_model.layers:
             if swiglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -370,7 +382,7 @@ def apply_liger_kernel_to_mistral(

         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -442,7 +454,7 @@ def apply_liger_kernel_to_mixtral(
         for decoder_layer in base_model.layers:
             if swiglu:
                 for expert in decoder_layer.block_sparse_moe.experts:
-                    _bind_method_to_module(expert, "forward", LigerBlockSparseTop2MLP.forward)
+                    _patch_swiglu_module(expert, LigerBlockSparseTop2MLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -516,7 +528,7 @@ def apply_liger_kernel_to_gemma(

         for decoder_layer in base_model.layers:
             if geglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
+                _patch_geglu_module(decoder_layer.mlp)
             if rms_norm:
                 _patch_rms_norm_module_for_gemma(decoder_layer.input_layernorm)
                 _patch_rms_norm_module_for_gemma(decoder_layer.post_attention_layernorm)
@@ -592,7 +604,7 @@ def apply_liger_kernel_to_gemma2(

         for decoder_layer in base_model.layers:
             if geglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
+                _patch_geglu_module(decoder_layer.mlp)
             if rms_norm:
                 _patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm)
                 _patch_rms_norm_module_for_gemma2(decoder_layer.post_attention_layernorm)
@@ -776,7 +788,7 @@ def apply_liger_kernel_to_qwen2(

         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -849,7 +861,7 @@ def apply_liger_kernel_to_qwen2_vl(
         _patch_rms_norm_module(base_model.norm)
         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -916,7 +928,7 @@ def apply_liger_kernel_to_qwen2_5_vl(
         _patch_rms_norm_module(base_model.norm)
         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -985,7 +997,7 @@ def apply_liger_kernel_to_phi3(

         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward)
+                _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -1048,7 +1060,7 @@ def apply_liger_kernel_to_olmo2(

         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
                 _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
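Every per-model hunk above swaps the direct _bind_method_to_module call for the new helpers, so MLP blocks patched on an instantiated model now report their Liger class names. A hedged usage sketch (keyword arguments can vary between releases; check the liger_kernel documentation for your version):

    # Assumes transformers and liger-kernel-nightly are installed and the checkpoint is accessible.
    from transformers import AutoModelForCausalLM
    from liger_kernel.transformers import apply_liger_kernel_to_llama

    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
    # Patch the already-instantiated model in place; swiglu/rms_norm select which kernels to swap in.
    apply_liger_kernel_to_llama(rms_norm=True, swiglu=True, model=model)
    print(model)  # with this release, patched modules show up as LigerSwiGLUMLP / LigerRMSNorm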

METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel_nightly
-Version: 0.5.5.dev20250324181221
+Version: 0.5.5.dev20250327235249
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
         Copyright 2024 LinkedIn Corporation

RECORD
@@ -6,7 +6,7 @@ liger_kernel/chunked_loss/__init__.py,sha256=ATu-xX5Fc49Cr6yBOGBRNTo593ZrU5ZCsIu
 liger_kernel/chunked_loss/cpo_loss.py,sha256=Gzz1eU4kgcbdubFVRy55e8A1Cr-r45UgNicXwZIjmBU,5454
 liger_kernel/chunked_loss/dpo_loss.py,sha256=xZwGqS04si9zXyob95SAdalC-hajZg8fWINqiqffN8k,5855
 liger_kernel/chunked_loss/functional.py,sha256=THWWpCnRVhTVfnPnyvQjdBvo1JDtxhwLmtZE_yiBBqM,817
-liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=oeZhRw87UUo01UotfaMxDhWa7Xr6IERmK3zzF1CQqEc,11037
+liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=y7e2mF_6HGMNNuoWAmJ8Y5bK-hRUe2q4-R6r7lf-Mw8,11934
 liger_kernel/chunked_loss/fused_linear_preference.py,sha256=ojB42jYPu0c4ki96Ft-hy7Sf6fh_WikG-aWNrlZzSio,18362
 liger_kernel/chunked_loss/fused_linear_rlhf.py,sha256=wGujqwLz91mOE9MmdenhBIKvbmswhwtINMCpcP7D74c,9050
 liger_kernel/chunked_loss/fused_linear_unpaired_preference.py,sha256=RiuK3UtRwH9T6jZ36sA8Urj-TVuOLOO2syLg_JOQapY,13437
@@ -45,7 +45,7 @@ liger_kernel/transformers/group_norm.py,sha256=6qMAWOprr4SzP0YhNVNGQIBpM5aUHplUD
 liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCcScY,2979
 liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
 liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
-liger_kernel/transformers/monkey_patch.py,sha256=qRCgchODu6AuO8la6uAnrDEA-sSP9ADt8IOp4kl-Dd0,52053
+liger_kernel/transformers/monkey_patch.py,sha256=_-4oMqEq5mQCSWQ7PaNI9cbLdT_UPPobYaqboa1oN4I,52210
 liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
 liger_kernel/transformers/rms_norm.py,sha256=GqCEJuGt0YdqqlMcToE0Wp4A8YFquDa4UUSyH2uFW2A,1191
 liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
@@ -71,9 +71,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.5.5.dev20250324181221.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
-liger_kernel_nightly-0.5.5.dev20250324181221.dist-info/METADATA,sha256=NyKmdw6KevABFKKrqEdmIf8agklqARr8azTzS4RRx0k,22959
-liger_kernel_nightly-0.5.5.dev20250324181221.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
-liger_kernel_nightly-0.5.5.dev20250324181221.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
-liger_kernel_nightly-0.5.5.dev20250324181221.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
-liger_kernel_nightly-0.5.5.dev20250324181221.dist-info/RECORD,,
+liger_kernel_nightly-0.5.5.dev20250327235249.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.5.dev20250327235249.dist-info/METADATA,sha256=4_bQ76AZvAHUe6dzZt_JTtxjAX7_UV6O5zLmi7RNmK4,22959
+liger_kernel_nightly-0.5.5.dev20250327235249.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.5.dev20250327235249.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+liger_kernel_nightly-0.5.5.dev20250327235249.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.5.dev20250327235249.dist-info/RECORD,,