liger-kernel-nightly 0.5.2.dev20241223032015__py3-none-any.whl → 0.5.2.dev20241223042135__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (57) hide show
  1. liger_kernel/chunked_loss/cpo_loss.py +5 -11
  2. liger_kernel/chunked_loss/dpo_loss.py +1 -4
  3. liger_kernel/chunked_loss/fused_linear_distillation.py +37 -37
  4. liger_kernel/chunked_loss/fused_linear_preference.py +40 -64
  5. liger_kernel/chunked_loss/orpo_loss.py +2 -6
  6. liger_kernel/chunked_loss/simpo_loss.py +4 -8
  7. liger_kernel/env_report.py +4 -11
  8. liger_kernel/ops/cross_entropy.py +7 -10
  9. liger_kernel/ops/experimental/embedding.py +1 -3
  10. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  11. liger_kernel/ops/fused_linear_cross_entropy.py +7 -15
  12. liger_kernel/ops/fused_linear_jsd.py +11 -29
  13. liger_kernel/ops/geglu.py +6 -17
  14. liger_kernel/ops/group_norm.py +11 -28
  15. liger_kernel/ops/jsd.py +2 -6
  16. liger_kernel/ops/kl_div.py +4 -7
  17. liger_kernel/ops/layer_norm.py +3 -5
  18. liger_kernel/ops/qwen2vl_mrope.py +8 -25
  19. liger_kernel/ops/rms_norm.py +11 -29
  20. liger_kernel/ops/rope.py +31 -33
  21. liger_kernel/ops/swiglu.py +4 -8
  22. liger_kernel/ops/utils.py +2 -0
  23. liger_kernel/transformers/__init__.py +16 -24
  24. liger_kernel/transformers/auto_model.py +6 -13
  25. liger_kernel/transformers/cross_entropy.py +1 -3
  26. liger_kernel/transformers/experimental/embedding.py +1 -3
  27. liger_kernel/transformers/functional.py +2 -6
  28. liger_kernel/transformers/fused_linear_cross_entropy.py +2 -6
  29. liger_kernel/transformers/geglu.py +1 -4
  30. liger_kernel/transformers/group_norm.py +3 -9
  31. liger_kernel/transformers/jsd.py +1 -3
  32. liger_kernel/transformers/kl_div.py +1 -3
  33. liger_kernel/transformers/layer_norm.py +3 -9
  34. liger_kernel/transformers/model/gemma.py +18 -40
  35. liger_kernel/transformers/model/gemma2.py +19 -41
  36. liger_kernel/transformers/model/llama.py +22 -48
  37. liger_kernel/transformers/model/mistral.py +14 -26
  38. liger_kernel/transformers/model/mixtral.py +23 -53
  39. liger_kernel/transformers/model/mllama.py +16 -36
  40. liger_kernel/transformers/model/phi3.py +18 -40
  41. liger_kernel/transformers/model/qwen2.py +18 -40
  42. liger_kernel/transformers/model/qwen2_vl.py +16 -30
  43. liger_kernel/transformers/monkey_patch.py +43 -117
  44. liger_kernel/transformers/rms_norm.py +4 -4
  45. liger_kernel/transformers/rope.py +2 -2
  46. liger_kernel/transformers/swiglu.py +2 -8
  47. liger_kernel/transformers/trainer/__init__.py +1 -3
  48. liger_kernel/transformers/trainer/orpo_trainer.py +13 -16
  49. liger_kernel/triton/__init__.py +1 -3
  50. liger_kernel/triton/monkey_patch.py +1 -3
  51. {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/METADATA +1 -1
  52. liger_kernel_nightly-0.5.2.dev20241223042135.dist-info/RECORD +66 -0
  53. liger_kernel_nightly-0.5.2.dev20241223032015.dist-info/RECORD +0 -66
  54. {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/LICENSE +0 -0
  55. {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/NOTICE +0 -0
  56. {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/WHEEL +0 -0
  57. {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,11 @@
1
1
  import inspect
2
2
  import logging
3
+
3
4
  from functools import partial
4
5
  from typing import Callable
5
6
 
6
7
  import transformers
8
+
7
9
  from packaging import version
8
10
  from transformers import PreTrainedModel
9
11
 
@@ -12,38 +14,24 @@ from liger_kernel.transformers.functional import liger_cross_entropy
12
14
  from liger_kernel.transformers.geglu import LigerGEGLUMLP
13
15
  from liger_kernel.transformers.layer_norm import LigerLayerNorm
14
16
  from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
15
- from liger_kernel.transformers.model.gemma import (
16
- lce_forward_deprecated as gemma_lce_forward_deprecated,
17
- )
17
+ from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
18
18
  from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
19
- from liger_kernel.transformers.model.gemma2 import (
20
- lce_forward_deprecated as gemma2_lce_forward_deprected,
21
- )
19
+ from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
22
20
  from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
23
- from liger_kernel.transformers.model.llama import (
24
- lce_forward_deprecated as llama_lce_forward_deprecated,
25
- )
21
+ from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
26
22
  from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
27
23
  from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
28
- from liger_kernel.transformers.model.mixtral import (
29
- lce_forward_deprecated as mixtral_lce_forward_deprecated,
30
- )
24
+ from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated
31
25
  from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
32
- from liger_kernel.transformers.model.phi3 import (
33
- lce_forward_deprecated as phi3_lce_forward_deprecated,
34
- )
26
+ from liger_kernel.transformers.model.phi3 import lce_forward_deprecated as phi3_lce_forward_deprecated
35
27
  from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
36
- from liger_kernel.transformers.model.qwen2 import (
37
- lce_forward_deprecated as qwen2_lce_forward_deprecated,
38
- )
28
+ from liger_kernel.transformers.model.qwen2 import lce_forward_deprecated as qwen2_lce_forward_deprecated
39
29
  from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
40
30
  from liger_kernel.transformers.rms_norm import LigerRMSNorm
41
31
  from liger_kernel.transformers.rope import liger_rotary_pos_emb
42
- from liger_kernel.transformers.swiglu import (
43
- LigerBlockSparseTop2MLP,
44
- LigerPhi3SwiGLUMLP,
45
- LigerSwiGLUMLP,
46
- )
32
+ from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
33
+ from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
34
+ from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
47
35
 
48
36
  transformer_version = version.parse(transformers.__version__)
49
37
 
@@ -57,23 +45,17 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable):
57
45
  module.__dict__[method_name] = new_method.__get__(module, module.__class__)
58
46
 
59
47
 
60
- def _patch_rms_norm_module(
61
- module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True
62
- ):
48
+ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True):
63
49
  module.offset = offset
64
50
  module.casting_mode = casting_mode
65
- module.variance_epsilon = (
66
- getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
67
- )
51
+ module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
68
52
  module.in_place = in_place
69
53
  _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
70
54
  _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
71
55
 
72
56
 
73
57
  def _patch_layer_norm_module(module, eps=1e-6):
74
- module.variance_epsilon = (
75
- getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
76
- )
58
+ module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
77
59
  module.hidden_size = module.normalized_shape
78
60
  _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
79
61
  _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
@@ -145,9 +127,7 @@ def apply_liger_kernel_to_llama(
145
127
 
146
128
  for decoder_layer in base_model.layers:
147
129
  if swiglu:
148
- _bind_method_to_module(
149
- decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
150
- )
130
+ _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
151
131
  if rms_norm:
152
132
  _patch_rms_norm_module(decoder_layer.input_layernorm)
153
133
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -184,17 +164,13 @@ def apply_liger_kernel_to_mllama(
184
164
  ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
185
165
 
186
166
  from transformers.models.mllama import modeling_mllama
187
- from transformers.models.mllama.modeling_mllama import (
188
- MllamaForCausalLM,
189
- MllamaForConditionalGeneration,
190
- MllamaTextModel,
191
- MllamaVisionModel,
192
- )
167
+ from transformers.models.mllama.modeling_mllama import MllamaForCausalLM
168
+ from transformers.models.mllama.modeling_mllama import MllamaForConditionalGeneration
169
+ from transformers.models.mllama.modeling_mllama import MllamaTextModel
170
+ from transformers.models.mllama.modeling_mllama import MllamaVisionModel
193
171
 
194
172
  from liger_kernel.transformers.model.mllama import lce_forward as mllama_lce_forward
195
- from liger_kernel.transformers.model.mllama import (
196
- lce_forward_deprecated as mllama_lce_forward_deprecated,
197
- )
173
+ from liger_kernel.transformers.model.mllama import lce_forward_deprecated as mllama_lce_forward_deprecated
198
174
 
199
175
  if rope:
200
176
  modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
@@ -241,9 +217,7 @@ def apply_liger_kernel_to_mllama(
241
217
  _patch_rms_norm_module(text_model.norm)
242
218
  for decoder_layer in text_model.layers:
243
219
  if swiglu:
244
- _bind_method_to_module(
245
- decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
246
- )
220
+ _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
247
221
  if rms_norm:
248
222
  _patch_rms_norm_module(decoder_layer.input_layernorm)
249
223
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -317,9 +291,7 @@ def apply_liger_kernel_to_mistral(
317
291
 
318
292
  for decoder_layer in base_model.layers:
319
293
  if swiglu:
320
- _bind_method_to_module(
321
- decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
322
- )
294
+ _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
323
295
  if rms_norm:
324
296
  _patch_rms_norm_module(decoder_layer.input_layernorm)
325
297
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -391,9 +363,7 @@ def apply_liger_kernel_to_mixtral(
391
363
  for decoder_layer in base_model.layers:
392
364
  if swiglu:
393
365
  for expert in decoder_layer.block_sparse_moe.experts:
394
- _bind_method_to_module(
395
- expert, "forward", LigerBlockSparseTop2MLP.forward
396
- )
366
+ _bind_method_to_module(expert, "forward", LigerBlockSparseTop2MLP.forward)
397
367
  if rms_norm:
398
368
  _patch_rms_norm_module(decoder_layer.input_layernorm)
399
369
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -431,12 +401,8 @@ def apply_liger_kernel_to_gemma(
431
401
  from transformers.models.gemma.modeling_gemma import GemmaModel
432
402
 
433
403
  # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
434
- LigerRMSNormForGemma = partial(
435
- LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
436
- )
437
- _patch_rms_norm_module_for_gemma = partial(
438
- _patch_rms_norm_module, casting_mode="gemma", offset=1.0
439
- )
404
+ LigerRMSNormForGemma = partial(LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma")
405
+ _patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0)
440
406
 
441
407
  if rope:
442
408
  modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
@@ -471,9 +437,7 @@ def apply_liger_kernel_to_gemma(
471
437
 
472
438
  for decoder_layer in base_model.layers:
473
439
  if geglu:
474
- _bind_method_to_module(
475
- decoder_layer.mlp, "forward", LigerGEGLUMLP.forward
476
- )
440
+ _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
477
441
  if rms_norm:
478
442
  _patch_rms_norm_module_for_gemma(decoder_layer.input_layernorm)
479
443
  _patch_rms_norm_module_for_gemma(decoder_layer.post_attention_layernorm)
@@ -510,9 +474,7 @@ def apply_liger_kernel_to_gemma2(
510
474
  from transformers.models.gemma2 import modeling_gemma2
511
475
  from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
512
476
 
513
- LigerRMSNormForGemma2 = partial(
514
- LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False
515
- )
477
+ LigerRMSNormForGemma2 = partial(LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False)
516
478
  _patch_rms_norm_module_for_gemma2 = partial(
517
479
  _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
518
480
  )
@@ -551,20 +513,12 @@ def apply_liger_kernel_to_gemma2(
551
513
 
552
514
  for decoder_layer in base_model.layers:
553
515
  if geglu:
554
- _bind_method_to_module(
555
- decoder_layer.mlp, "forward", LigerGEGLUMLP.forward
556
- )
516
+ _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
557
517
  if rms_norm:
558
518
  _patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm)
559
- _patch_rms_norm_module_for_gemma2(
560
- decoder_layer.post_attention_layernorm
561
- )
562
- _patch_rms_norm_module_for_gemma2(
563
- decoder_layer.pre_feedforward_layernorm
564
- )
565
- _patch_rms_norm_module_for_gemma2(
566
- decoder_layer.post_feedforward_layernorm
567
- )
519
+ _patch_rms_norm_module_for_gemma2(decoder_layer.post_attention_layernorm)
520
+ _patch_rms_norm_module_for_gemma2(decoder_layer.pre_feedforward_layernorm)
521
+ _patch_rms_norm_module_for_gemma2(decoder_layer.post_feedforward_layernorm)
568
522
 
569
523
 
570
524
  def apply_liger_kernel_to_qwen2(
@@ -633,9 +587,7 @@ def apply_liger_kernel_to_qwen2(
633
587
 
634
588
  for decoder_layer in base_model.layers:
635
589
  if swiglu:
636
- _bind_method_to_module(
637
- decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
638
- )
590
+ _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
639
591
  if rms_norm:
640
592
  _patch_rms_norm_module(decoder_layer.input_layernorm)
641
593
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -674,14 +626,10 @@ def apply_liger_kernel_to_qwen2_vl(
674
626
  from transformers.models.qwen2_vl import modeling_qwen2_vl
675
627
  from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel
676
628
 
677
- from liger_kernel.transformers.model.qwen2_vl import (
678
- lce_forward as qwen2_vl_lce_forward,
679
- )
629
+ from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward
680
630
 
681
631
  if rope:
682
- modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = (
683
- liger_multimodal_rotary_pos_emb
684
- )
632
+ modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
685
633
  if rms_norm:
686
634
  # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
687
635
  modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
@@ -712,9 +660,7 @@ def apply_liger_kernel_to_qwen2_vl(
712
660
  _patch_rms_norm_module(base_model.norm)
713
661
  for decoder_layer in base_model.layers:
714
662
  if swiglu:
715
- _bind_method_to_module(
716
- decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
717
- )
663
+ _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
718
664
  if rms_norm:
719
665
  _patch_rms_norm_module(decoder_layer.input_layernorm)
720
666
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -783,9 +729,7 @@ def apply_liger_kernel_to_phi3(
783
729
 
784
730
  for decoder_layer in base_model.layers:
785
731
  if swiglu:
786
- _bind_method_to_module(
787
- decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward
788
- )
732
+ _bind_method_to_module(decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward)
789
733
  if rms_norm:
790
734
  _patch_rms_norm_module(decoder_layer.input_layernorm)
791
735
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -826,24 +770,16 @@ def _apply_liger_kernel(model_type: str, **kwargs) -> None:
826
770
  return
827
771
 
828
772
  if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys():
829
- logger.info(
830
- f"There are currently no Liger kernels supported for model type: {model_type}."
831
- )
773
+ logger.info(f"There are currently no Liger kernels supported for model type: {model_type}.")
832
774
  return
833
775
 
834
776
  apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
835
777
  apply_fn_signature = inspect.signature(apply_fn)
836
778
 
837
779
  # Filter out the keyword arguments that are not supported by the apply function
838
- applicable_kwargs = {
839
- key: value
840
- for key, value in kwargs.items()
841
- if key in apply_fn_signature.parameters
842
- }
780
+ applicable_kwargs = {key: value for key, value in kwargs.items() if key in apply_fn_signature.parameters}
843
781
 
844
- logger.info(
845
- f"Applying Liger kernels for model type: {model_type} with kwargs: {applicable_kwargs}"
846
- )
782
+ logger.info(f"Applying Liger kernels for model type: {model_type} with kwargs: {applicable_kwargs}")
847
783
 
848
784
  # Assume this is invoked pre-model initialization, so we only need to patch transformers code
849
785
  apply_fn(**applicable_kwargs)
@@ -857,20 +793,14 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
857
793
  - model: the model instance to apply Liger kernels to
858
794
  - kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function.
859
795
  """
860
- model_type = getattr(model, "config", None) and getattr(
861
- model.config, "model_type", None
862
- )
796
+ model_type = getattr(model, "config", None) and getattr(model.config, "model_type", None)
863
797
 
864
798
  if not model_type:
865
- logger.info(
866
- "Model type could not be determined from model config. No Liger kernels will be applied."
867
- )
799
+ logger.info("Model type could not be determined from model config. No Liger kernels will be applied.")
868
800
  return
869
801
 
870
802
  if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys():
871
- logger.info(
872
- f"There are currently no Liger kernels supported for model type: {model_type}."
873
- )
803
+ logger.info(f"There are currently no Liger kernels supported for model type: {model_type}.")
874
804
  return
875
805
 
876
806
  apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
@@ -878,11 +808,7 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
878
808
  apply_fn_signature = inspect.signature(apply_fn)
879
809
 
880
810
  # Filter out the keyword arguments that are not supported by the apply function
881
- applicable_kwargs = {
882
- key: value
883
- for key, value in kwargs.items()
884
- if key in apply_fn_signature.parameters
885
- }
811
+ applicable_kwargs = {key: value for key, value in kwargs.items() if key in apply_fn_signature.parameters}
886
812
  logger.info(
887
813
  f"Applying Liger kernels to model instance with model type: {model_type} with kwargs: {applicable_kwargs}"
888
814
  )
@@ -19,9 +19,7 @@ class LigerRMSNorm(nn.Module):
19
19
  "ones",
20
20
  "zeros",
21
21
  ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
22
- self.weight = nn.Parameter(
23
- torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)
24
- )
22
+ self.weight = nn.Parameter(torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size))
25
23
  self.variance_epsilon, self.offset, self.casting_mode, self.in_place = (
26
24
  eps,
27
25
  offset,
@@ -40,4 +38,6 @@ class LigerRMSNorm(nn.Module):
40
38
  )
41
39
 
42
40
  def extra_repr(self):
43
- return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}"
41
+ return (
42
+ f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}"
43
+ )
@@ -8,8 +8,8 @@ def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
8
8
  Args:
9
9
  q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim).
10
10
  k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim).
11
- cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim).
12
- sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim).
11
+ cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim) or (bsz, seq_len, head_dim).
12
+ sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim) or (bsz, seq_len, head_dim).
13
13
  position_ids (torch.Tensor, optional): The position ids tensor. Defaults to None.
14
14
  unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1.
15
15
 
@@ -16,10 +16,7 @@ class LigerSwiGLUMLP(nn.Module):
16
16
  raise ValueError(f"Activation function {config.hidden_act} not supported.")
17
17
 
18
18
  def forward(self, x):
19
-
20
- return self.down_proj(
21
- LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x))
22
- )
19
+ return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
23
20
 
24
21
 
25
22
  class LigerBlockSparseTop2MLP(nn.Module):
@@ -36,7 +33,6 @@ class LigerBlockSparseTop2MLP(nn.Module):
36
33
  raise ValueError(f"Activation function {config.hidden_act} not supported.")
37
34
 
38
35
  def forward(self, x):
39
-
40
36
  return self.w2(LigerSiLUMulFunction.apply(self.w1(x), self.w3(x)))
41
37
 
42
38
 
@@ -51,9 +47,7 @@ class LigerPhi3SwiGLUMLP(nn.Module):
51
47
  self.config = config
52
48
  self.hidden_size = config.hidden_size
53
49
  self.intermediate_size = config.intermediate_size
54
- self.gate_up_proj = nn.Linear(
55
- self.hidden_size, 2 * self.intermediate_size, bias=False
56
- )
50
+ self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)
57
51
  self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
58
52
  if config.hidden_act not in ["silu", "swish"]:
59
53
  raise ValueError(f"Activation function {config.hidden_act} not supported.")
@@ -1,6 +1,4 @@
1
1
  try:
2
- from liger_kernel.transformers.trainer.orpo_trainer import ( # noqa: F401
3
- LigerORPOTrainer,
4
- )
2
+ from liger_kernel.transformers.trainer.orpo_trainer import LigerORPOTrainer # noqa: F401
5
3
  except ImportError:
6
4
  raise ImportError("Please `pip install trl` to use LigerORPOTrainer")
@@ -1,7 +1,14 @@
1
- from typing import Any, Callable, Dict, List, Literal, Tuple, Union
1
+ from typing import Any
2
+ from typing import Callable
3
+ from typing import Dict
4
+ from typing import List
5
+ from typing import Literal
6
+ from typing import Tuple
7
+ from typing import Union
2
8
 
3
9
  import torch
4
10
  import torch.nn as nn
11
+
5
12
  from torch.distributed.fsdp import FullyShardedDataParallel
6
13
  from trl.trainer import ORPOTrainer
7
14
 
@@ -62,9 +69,7 @@ class _FSDPForwardRedirection:
62
69
  class LigerORPOTrainer(ORPOTrainer):
63
70
  def concatenated_forward(
64
71
  self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
65
- ) -> Tuple[
66
- torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor
67
- ]:
72
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
68
73
  """
69
74
  Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
70
75
  We do this to avoid doing two forward passes, because it's faster for FSDP.
@@ -79,9 +84,7 @@ class LigerORPOTrainer(ORPOTrainer):
79
84
 
80
85
  model_kwargs = (
81
86
  {
82
- "decoder_input_ids": self._shift_right(
83
- concatenated_batch["concatenated_labels"]
84
- ),
87
+ "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
85
88
  }
86
89
  if self.is_encoder_decoder
87
90
  else {}
@@ -109,14 +112,10 @@ class LigerORPOTrainer(ORPOTrainer):
109
112
  **model_kwargs,
110
113
  )
111
114
 
112
- orpo_loss_fn = LigerFusedLinearORPOLoss(
113
- ignore_index=self.label_pad_token_id, beta=self.beta
114
- )
115
+ orpo_loss_fn = LigerFusedLinearORPOLoss(ignore_index=self.label_pad_token_id, beta=self.beta)
115
116
 
116
117
  def orpo_partial(lm_head, last_hidden_state, concatenated_labels):
117
- return orpo_loss_fn(
118
- lm_head.weight, last_hidden_state, concatenated_labels, lm_head.bias
119
- )
118
+ return orpo_loss_fn(lm_head.weight, last_hidden_state, concatenated_labels, lm_head.bias)
120
119
 
121
120
  orpo_loss, aux_outputs = _FSDPForwardRedirection()(
122
121
  model,
@@ -149,9 +148,7 @@ class LigerORPOTrainer(ORPOTrainer):
149
148
  ) = aux_outputs[:5]
150
149
 
151
150
  # return loss, metrics
152
- chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = aux_outputs[
153
- 5:
154
- ]
151
+ chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = aux_outputs[5:]
155
152
 
156
153
  reward_accuracies = (chosen_rewards > rejected_rewards).float()
157
154
 
@@ -1,3 +1 @@
1
- from liger_kernel.triton.monkey_patch import ( # noqa: F401
2
- apply_liger_triton_cache_manager,
3
- )
1
+ from liger_kernel.triton.monkey_patch import apply_liger_triton_cache_manager # noqa: F401
@@ -37,6 +37,4 @@ def apply_liger_triton_cache_manager():
37
37
  Experimental feature to get around transient FileNotFoundError in triton compilation.
38
38
  For more details please see https://github.com/triton-lang/triton/pull/4295
39
39
  """
40
- os.environ["TRITON_CACHE_MANAGER"] = (
41
- "liger_kernel.triton.monkey_patch:LigerTritonFileCacheManager"
42
- )
40
+ os.environ["TRITON_CACHE_MANAGER"] = "liger_kernel.triton.monkey_patch:LigerTritonFileCacheManager"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.2.dev20241223032015
3
+ Version: 0.5.2.dev20241223042135
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -0,0 +1,66 @@
1
+ liger_kernel/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ liger_kernel/env_report.py,sha256=uhdEC8OydxoZlb7B6YYcAaBF3crGFdIck-4cxaW4NJY,1728
3
+ liger_kernel/utils.py,sha256=HJa-xVKOohDn6pLVIx-Fv0V9h0QAL3qZGQNRICI-OpI,249
4
+ liger_kernel/chunked_loss/README.md,sha256=K6rucm6nqHpWCmxUOhBYcE3apwQxAy0TfRUippR7Icw,2243
5
+ liger_kernel/chunked_loss/__init__.py,sha256=R2wCcz4Y0kTAve926DH3k182XKezpXeACMHj05g9Mm8,346
6
+ liger_kernel/chunked_loss/cpo_loss.py,sha256=H-BU2QC5GzNQ4NnTM6TLgwvo-Eoh5YAE-Q_j1dX_w0g,3517
7
+ liger_kernel/chunked_loss/dpo_loss.py,sha256=VYZMOafdvE8xlhvTtwjrz81tIzxR1mHF4lXdsADnIQg,4373
8
+ liger_kernel/chunked_loss/functional.py,sha256=9Gr-YXIuEzEJkBUhDx3G2fuQayckLor7cC7svhmPML4,549
9
+ liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=M-QWvGPnWefYDn6Hr9bPn7diMNP5qrUaeWTb_zdMO4E,10265
10
+ liger_kernel/chunked_loss/fused_linear_preference.py,sha256=25sTgvphLKAR0jyJcrsJPKK1abFpTKrajSyAx8nJ3bc,16134
11
+ liger_kernel/chunked_loss/orpo_loss.py,sha256=jbZxx-EjPK71A6CSyNzTOAIEQgAUjfvwSViw6R_pPXQ,3510
12
+ liger_kernel/chunked_loss/simpo_loss.py,sha256=ZvDIjT9EQrbwzH2LNZMhv84SPsOHGi_Ywk95vgA0b_o,3736
13
+ liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
+ liger_kernel/ops/cross_entropy.py,sha256=2OPIkSXeQAIfSCODYK45Jf8xrz7HoGqFHr1MHS_pijE,15895
15
+ liger_kernel/ops/fused_linear_cross_entropy.py,sha256=ObNZjgYlCvigbgKl-FAjHAvk90wiwJ-4Wrf8JUHmlLQ,9346
16
+ liger_kernel/ops/fused_linear_jsd.py,sha256=eKqaADj7LgWfoYqyH03tjrmhNTfJOF1Dhx_bWzBTnTU,9600
17
+ liger_kernel/ops/geglu.py,sha256=axGvCIvlBzuluoAIrWTsp2iZM4BFKNInkPov8YVvH9E,4126
18
+ liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2woggk,10837
19
+ liger_kernel/ops/jsd.py,sha256=WwGY9ozuH3PMg3udRI6H96UqAEzIozJoO2HtHg7010M,6107
20
+ liger_kernel/ops/kl_div.py,sha256=MnfuYqqQESON1X2Swy064x1urKtMFdgeSWd60VttBXI,8420
21
+ liger_kernel/ops/layer_norm.py,sha256=quvt2zcwcJCDxrgm-iWoHzDYOoeZdMC76nZ_ckw6-p8,7640
22
+ liger_kernel/ops/qwen2vl_mrope.py,sha256=3GExhYpLgB4VUtyZyjRk8XjEur3W4EWF6HQ67ML5vBU,8481
23
+ liger_kernel/ops/rms_norm.py,sha256=PWLJcdIKU5e-8BuYFHd9Cqlq6wmr6fUXKi9zQD4LetU,11727
24
+ liger_kernel/ops/rope.py,sha256=ofmBOkUpZZO-Q8Z5B_LOFYYLD-YT-8WnJ4vGOrDYouI,8943
25
+ liger_kernel/ops/swiglu.py,sha256=KmgMjaJQnbLLgZn2nEpbwHU_xpnYRweCyrLQSVvM1vA,3015
26
+ liger_kernel/ops/utils.py,sha256=vMWxfcw02xUvjpEXQQ3Rrj68ddZ8Of3hiOmEFq1zSKg,3852
27
+ liger_kernel/ops/experimental/embedding.py,sha256=tolj3tItkzpSb30zWqDN2_yX4ectflaQ8HMyKyFIQc8,4172
28
+ liger_kernel/ops/experimental/mm_int8int2.py,sha256=TrS9lpwekrik_w5qE7AhMJD1bcq-OidjtbsW80oZ6IM,13314
29
+ liger_kernel/transformers/__init__.py,sha256=QPmYkL6hosBPpPqCUGqvIvAtD9XzLgvZqZxUyYMZeVk,2008
30
+ liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawXOGgoCZjgEg,1545
31
+ liger_kernel/transformers/cross_entropy.py,sha256=s5-ZM1NBMDjG-KKJKBtIkmArj1jCUjDnpL-2QKhKYho,1734
32
+ liger_kernel/transformers/functional.py,sha256=hxReSBDEUZkOnZgURD8sf6ETYvf9yqCOOMU2k9Ywh90,4435
33
+ liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=K4tfpoNPUJpWv7rCHEcs5xhJLg5td8GcpJrAryF5NMk,1451
34
+ liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
35
+ liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
36
+ liger_kernel/transformers/group_norm.py,sha256=URmjkQFsrbMffzcJiGpX7ckxWlpL95AiJS-80hwAWPk,2173
37
+ liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCcScY,2979
38
+ liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
39
+ liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
40
+ liger_kernel/transformers/monkey_patch.py,sha256=6eXmtERKr4YUppRAaH7a_ml3AOz0ao68E8QnOyXtIkY,37794
41
+ liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
42
+ liger_kernel/transformers/rms_norm.py,sha256=GqCEJuGt0YdqqlMcToE0Wp4A8YFquDa4UUSyH2uFW2A,1191
43
+ liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
44
+ liger_kernel/transformers/swiglu.py,sha256=i9WTqcNRqReU4XJs391IPbl-I5X0wG4T72D4pqGFfJg,2422
45
+ liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx-uy2f2cFfveZpqbUdhw,123
46
+ liger_kernel/transformers/experimental/embedding.py,sha256=2P0QYdlFyFrG5OqTzTa1wcRgDSyjBMv5i1a7BrDPDQw,881
47
+ liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
48
+ liger_kernel/transformers/model/gemma.py,sha256=ky89b3aWPaeTGRMC-745KgixtQIRXzNAiCORAMLn9yo,9654
49
+ liger_kernel/transformers/model/gemma2.py,sha256=27NcoZjEqP9Lqb4Wf0EKqTbr2HTGiHPhrVyPCRGPz6s,10767
50
+ liger_kernel/transformers/model/llama.py,sha256=3LJFXKFDKvEakaWPc_NicSFst4Y_hdSMrdl1UDK1EcA,10330
51
+ liger_kernel/transformers/model/mistral.py,sha256=MVRksI5_j_8WJu8znOHKCdSI5jSu-S7cdFYzt9m_vIQ,5180
52
+ liger_kernel/transformers/model/mixtral.py,sha256=jpZJkpl625Q-JHWarj2MqT5mRaSsiCtg0c9vVyvOdCY,11430
53
+ liger_kernel/transformers/model/mllama.py,sha256=qWexBdskuN3gPJvPUwt4J0nU675tGD6W7wxgRZ9Bifg,11145
54
+ liger_kernel/transformers/model/phi3.py,sha256=biRa8fph9qdnQmkD9I21t5XIjpIt1i6UKU4uk8Up8pU,10292
55
+ liger_kernel/transformers/model/qwen2.py,sha256=14UuPjxB-tjqWn85Tn4fqBFvVhVsth5iPEt8kJSMiew,9581
56
+ liger_kernel/transformers/model/qwen2_vl.py,sha256=rZg3nU3YgF6wkB1UJ0a9IACSIlVOSCyLltyqw951MQQ,8609
57
+ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7HHWHwku25A-GYL0WU,193
58
+ liger_kernel/transformers/trainer/orpo_trainer.py,sha256=MId1S_MfA3pPVQA1rkiKxp-jZDNz8VmvZzXC-Kugol4,7662
59
+ liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
60
+ liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
61
+ liger_kernel_nightly-0.5.2.dev20241223042135.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
62
+ liger_kernel_nightly-0.5.2.dev20241223042135.dist-info/METADATA,sha256=diXsKJ9zCLk-w9SCZLWWx-xN0ZP8-W51KrgpISmaxn4,21055
63
+ liger_kernel_nightly-0.5.2.dev20241223042135.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
64
+ liger_kernel_nightly-0.5.2.dev20241223042135.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
65
+ liger_kernel_nightly-0.5.2.dev20241223042135.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
66
+ liger_kernel_nightly-0.5.2.dev20241223042135.dist-info/RECORD,,
@@ -1,66 +0,0 @@
1
- liger_kernel/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- liger_kernel/env_report.py,sha256=ok9PMXtO-8uLj_feCJI4h9hz2NtolZ2AG_OJTW5qmo4,1823
3
- liger_kernel/utils.py,sha256=HJa-xVKOohDn6pLVIx-Fv0V9h0QAL3qZGQNRICI-OpI,249
4
- liger_kernel/chunked_loss/README.md,sha256=K6rucm6nqHpWCmxUOhBYcE3apwQxAy0TfRUippR7Icw,2243
5
- liger_kernel/chunked_loss/__init__.py,sha256=R2wCcz4Y0kTAve926DH3k182XKezpXeACMHj05g9Mm8,346
6
- liger_kernel/chunked_loss/cpo_loss.py,sha256=3PdSp1gju1u0ffFGpUufbZPIva8aI3SW1TfqkJOpw1g,3554
7
- liger_kernel/chunked_loss/dpo_loss.py,sha256=jbTno1pKEc-HxAGFY3NSycBzdWyTacyRCzH3FhrMUMo,4383
8
- liger_kernel/chunked_loss/functional.py,sha256=9Gr-YXIuEzEJkBUhDx3G2fuQayckLor7cC7svhmPML4,549
9
- liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=2BH6DCPjsR2zS6zcwFPcIIZRhLF8SohjGdKsAJ_301o,10222
10
- liger_kernel/chunked_loss/fused_linear_preference.py,sha256=vvratrj8rba8NaGbO2ffbUfWMVEvDMxDCo6SI8nCtbo,16376
11
- liger_kernel/chunked_loss/orpo_loss.py,sha256=xHsKjlCWQVew7_hhpyUp3a1wd0tdpgx-zQAezNjk3Q4,3532
12
- liger_kernel/chunked_loss/simpo_loss.py,sha256=_5gXIkEAT0Kt_AufziQlYhBjzDJVSQVk7oSDHcrw1xw,3759
13
- liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
- liger_kernel/ops/cross_entropy.py,sha256=3oPrw6KzIVc11gSyfdrLnj0WJB4qOYjE1tC8HJeFFpg,15888
15
- liger_kernel/ops/fused_linear_cross_entropy.py,sha256=Tnw4gyAYVVdnCOqhOuLEzbUQ3goOTnoAfk3pqSIM5ac,9301
16
- liger_kernel/ops/fused_linear_jsd.py,sha256=nOv4zwfxHqqepKEmMsQuz-B3H-gRjyo8uClpmqSGLYA,9693
17
- liger_kernel/ops/geglu.py,sha256=MQL4zyzneZqZYUGPvb1QjI_EYT9_pKfSDgR25WD9jrI,4127
18
- liger_kernel/ops/group_norm.py,sha256=VaRErVJGR4JqgXXvuIjNGTn3E2egjLtU1y3ymwIf4d8,10961
19
- liger_kernel/ops/jsd.py,sha256=Ap2b0_geCl6fqBXLI1IS6Yn6GlO-8LgPmnOW3y47dus,6151
20
- liger_kernel/ops/kl_div.py,sha256=vBz1ieu_sPcFbgG_wL0SwrbSQ6xVDK51_FNo-yf7CjY,8430
21
- liger_kernel/ops/layer_norm.py,sha256=_CZggw3GNEIUx5weDzadFit5I-Lzosoo8prgeJzcViY,7589
22
- liger_kernel/ops/qwen2vl_mrope.py,sha256=GvP4Cg-2ClYyiqbe7bB_OMvnlZooBmqP2-9V8RMPde4,8598
23
- liger_kernel/ops/rms_norm.py,sha256=bleuRC9IS_P3zEX07b0LZ_cpgeTH8l5sdvkelucpRgM,11792
24
- liger_kernel/ops/rope.py,sha256=jrzaA9-6Orn44y_IIam9_YNPQxOFK2FrIRNfFea4EtU,8513
25
- liger_kernel/ops/swiglu.py,sha256=Fwxtd76rhHKT9ShQAGca9RsnASplAVxtYKHmiT73_yA,2994
26
- liger_kernel/ops/utils.py,sha256=_VQvd1PX5JXm5xaiBrk2gANp3qr4kM7qYG3ypkBwkMs,3850
27
- liger_kernel/ops/experimental/embedding.py,sha256=LYR66dB-jhvhtUjeV4PnNro-n77J1mdlmpSLSxB3Y6U,4186
28
- liger_kernel/ops/experimental/mm_int8int2.py,sha256=JpGVZCgRC6T8XMUJ_QbZRS2XU1bh0urIZphs5DTc1mY,13358
29
- liger_kernel/transformers/__init__.py,sha256=gia-eBxr7TLxU0GdDf8AfCY4WgDlFLqIGSt7EoQGsBA,1336
30
- liger_kernel/transformers/auto_model.py,sha256=RMIwQHSiXoksXFTIqFZ4PLBgoqkxJJAT3q1Qh47bGN8,1552
31
- liger_kernel/transformers/cross_entropy.py,sha256=yEm_YQ7oa3_BzT3hdW6KrAslduhSqWcJQVNZZDcWCg4,1758
32
- liger_kernel/transformers/functional.py,sha256=sUBoU8Vb4pLpr9G6IdkRsToYgh-rCXL4OLYat7Tv_GU,4450
33
- liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=_i0PXSp5iZ9pKXdEeZ4lvHCENJYjV4y74yz3ZRG5XQg,1484
34
- liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
35
- liger_kernel/transformers/geglu.py,sha256=QcrME_8ooIn0xa59LaC0aoOdRrBIFd11Y0bAyF0NfCw,1130
36
- liger_kernel/transformers/group_norm.py,sha256=FJ9R7mS9G1wO-GRIQ6QKSmIhnZ6nQ6GIkE4NnX_hnn0,2241
37
- liger_kernel/transformers/jsd.py,sha256=sbr8DnKSYZJH9pv2rpmboNijYGpZKbhb2-WSGp5_v6g,3001
38
- liger_kernel/transformers/kl_div.py,sha256=qVhjBg6tjRyue5iZ3NFxo8uySY4JuIFJyv0IM_50F24,431
39
- liger_kernel/transformers/layer_norm.py,sha256=fd6o4kSHJWolQMWxh-l1qObfgL08ruNbUoBiANKX1ow,972
40
- liger_kernel/transformers/monkey_patch.py,sha256=Fk2v4GZQDJzfh3Cpc6BHNJbs_tungDyWmqS9nuG9Lc4,38406
41
- liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
42
- liger_kernel/transformers/rms_norm.py,sha256=AHstklNIO1PLHjjCBU-TPuUD-Fl_pycJUTLlJNojbV8,1189
43
- liger_kernel/transformers/rope.py,sha256=m-ah8vZBYW8tfplTXCiAPMHJWlB1tdp_JPXJeWE-Boo,943
44
- liger_kernel/transformers/swiglu.py,sha256=0-tVJ8xEYfhxnduc16PflXFj8sZPxdx9sHUn3hfwCI4,2468
45
- liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx-uy2f2cFfveZpqbUdhw,123
46
- liger_kernel/transformers/experimental/embedding.py,sha256=HpckiAMKM8-SRxKDcGTqortVxnjhwpZsfsp9lfjqfeM,895
47
- liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
48
- liger_kernel/transformers/model/gemma.py,sha256=R4huxuR48gkLrdT8KqV7As2v9dZtEmcGVz6YG1ZmuJE,9692
49
- liger_kernel/transformers/model/gemma2.py,sha256=zxQsxCRqkoxCES3GJPVI7soUuF3J5HZDlvJgaBos1zM,10836
50
- liger_kernel/transformers/model/llama.py,sha256=RinsgC_eR-YNvZd2SHPQxZ4eyR3uViaTFCM3SvI5nks,10426
51
- liger_kernel/transformers/model/mistral.py,sha256=XpL1rlWg_llvW3z_Hf_d8WQs7uQaH4ds7EZ2SxjQHsU,5144
52
- liger_kernel/transformers/model/mixtral.py,sha256=JlNS6DA6SJqeHDk7j2LZymPQ3wngrTIo3wUGFBqHuJs,11504
53
- liger_kernel/transformers/model/mllama.py,sha256=mesNCgj0Ea1O-fqRD4LVxDJ1CR2abY_zAzK_bfVzkiU,11222
54
- liger_kernel/transformers/model/phi3.py,sha256=xUZPlaPKwknLjHc3uUW3EPodm1h0vD3G7Qnhh51v-Io,10332
55
- liger_kernel/transformers/model/qwen2.py,sha256=EyhSSzQOskGjSnCsKMZpd1s5IAIlHd5PBO3q0MoCs00,9619
56
- liger_kernel/transformers/model/qwen2_vl.py,sha256=bIQe2bWiY--G84FhCD29Gdi64_qHP6vbcGsK6vKysQE,8547
57
- liger_kernel/transformers/trainer/__init__.py,sha256=c4OQVJmhNOloj0JYSEc0j_cQuBbzGWILfaowUR1hmRw,210
58
- liger_kernel/transformers/trainer/orpo_trainer.py,sha256=O2k2vdHl-O1S-U61aEmyUFu3QrEuNAipQa2oUBb3HAA,7679
59
- liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
60
- liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
61
- liger_kernel_nightly-0.5.2.dev20241223032015.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
62
- liger_kernel_nightly-0.5.2.dev20241223032015.dist-info/METADATA,sha256=glSPMysElXhTUr1u74GrG_xjFSIek9GtE9AlPR6GkLs,21055
63
- liger_kernel_nightly-0.5.2.dev20241223032015.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
64
- liger_kernel_nightly-0.5.2.dev20241223032015.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
65
- liger_kernel_nightly-0.5.2.dev20241223032015.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
66
- liger_kernel_nightly-0.5.2.dev20241223032015.dist-info/RECORD,,