liger-kernel-nightly 0.5.2.dev20241223032015__py3-none-any.whl → 0.5.2.dev20241223042135__py3-none-any.whl
- liger_kernel/chunked_loss/cpo_loss.py +5 -11
- liger_kernel/chunked_loss/dpo_loss.py +1 -4
- liger_kernel/chunked_loss/fused_linear_distillation.py +37 -37
- liger_kernel/chunked_loss/fused_linear_preference.py +40 -64
- liger_kernel/chunked_loss/orpo_loss.py +2 -6
- liger_kernel/chunked_loss/simpo_loss.py +4 -8
- liger_kernel/env_report.py +4 -11
- liger_kernel/ops/cross_entropy.py +7 -10
- liger_kernel/ops/experimental/embedding.py +1 -3
- liger_kernel/ops/experimental/mm_int8int2.py +3 -9
- liger_kernel/ops/fused_linear_cross_entropy.py +7 -15
- liger_kernel/ops/fused_linear_jsd.py +11 -29
- liger_kernel/ops/geglu.py +6 -17
- liger_kernel/ops/group_norm.py +11 -28
- liger_kernel/ops/jsd.py +2 -6
- liger_kernel/ops/kl_div.py +4 -7
- liger_kernel/ops/layer_norm.py +3 -5
- liger_kernel/ops/qwen2vl_mrope.py +8 -25
- liger_kernel/ops/rms_norm.py +11 -29
- liger_kernel/ops/rope.py +31 -33
- liger_kernel/ops/swiglu.py +4 -8
- liger_kernel/ops/utils.py +2 -0
- liger_kernel/transformers/__init__.py +16 -24
- liger_kernel/transformers/auto_model.py +6 -13
- liger_kernel/transformers/cross_entropy.py +1 -3
- liger_kernel/transformers/experimental/embedding.py +1 -3
- liger_kernel/transformers/functional.py +2 -6
- liger_kernel/transformers/fused_linear_cross_entropy.py +2 -6
- liger_kernel/transformers/geglu.py +1 -4
- liger_kernel/transformers/group_norm.py +3 -9
- liger_kernel/transformers/jsd.py +1 -3
- liger_kernel/transformers/kl_div.py +1 -3
- liger_kernel/transformers/layer_norm.py +3 -9
- liger_kernel/transformers/model/gemma.py +18 -40
- liger_kernel/transformers/model/gemma2.py +19 -41
- liger_kernel/transformers/model/llama.py +22 -48
- liger_kernel/transformers/model/mistral.py +14 -26
- liger_kernel/transformers/model/mixtral.py +23 -53
- liger_kernel/transformers/model/mllama.py +16 -36
- liger_kernel/transformers/model/phi3.py +18 -40
- liger_kernel/transformers/model/qwen2.py +18 -40
- liger_kernel/transformers/model/qwen2_vl.py +16 -30
- liger_kernel/transformers/monkey_patch.py +43 -117
- liger_kernel/transformers/rms_norm.py +4 -4
- liger_kernel/transformers/rope.py +2 -2
- liger_kernel/transformers/swiglu.py +2 -8
- liger_kernel/transformers/trainer/__init__.py +1 -3
- liger_kernel/transformers/trainer/orpo_trainer.py +13 -16
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -3
- {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/METADATA +1 -1
- liger_kernel_nightly-0.5.2.dev20241223042135.dist-info/RECORD +66 -0
- liger_kernel_nightly-0.5.2.dev20241223032015.dist-info/RECORD +0 -66
- {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/monkey_patch.py CHANGED
@@ -1,9 +1,11 @@
 import inspect
 import logging
+
 from functools import partial
 from typing import Callable
 
 import transformers
+
 from packaging import version
 from transformers import PreTrainedModel
 
@@ -12,38 +14,24 @@ from liger_kernel.transformers.functional import liger_cross_entropy
 from liger_kernel.transformers.geglu import LigerGEGLUMLP
 from liger_kernel.transformers.layer_norm import LigerLayerNorm
 from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
-from liger_kernel.transformers.model.gemma import (
-    lce_forward_deprecated as gemma_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
 from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
-from liger_kernel.transformers.model.gemma2 import (
-    lce_forward_deprecated as gemma2_lce_forward_deprected,
-)
+from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
 from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
-from liger_kernel.transformers.model.llama import (
-    lce_forward_deprecated as llama_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
 from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
 from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
-from liger_kernel.transformers.model.mixtral import (
-    lce_forward_deprecated as mixtral_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated
 from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
-from liger_kernel.transformers.model.phi3 import (
-    lce_forward_deprecated as phi3_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.phi3 import lce_forward_deprecated as phi3_lce_forward_deprecated
 from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
-from liger_kernel.transformers.model.qwen2 import (
-    lce_forward_deprecated as qwen2_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.qwen2 import lce_forward_deprecated as qwen2_lce_forward_deprecated
 from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
-from liger_kernel.transformers.swiglu import (
-    LigerBlockSparseTop2MLP,
-    LigerPhi3SwiGLUMLP,
-    LigerSwiGLUMLP,
-)
+from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
+from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
+from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
 
 transformer_version = version.parse(transformers.__version__)
 
@@ -57,23 +45,17 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable):
     module.__dict__[method_name] = new_method.__get__(module, module.__class__)
 
 
-def _patch_rms_norm_module(
-    module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True
-):
+def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True):
     module.offset = offset
     module.casting_mode = casting_mode
-    module.variance_epsilon = (
-        getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
-    )
+    module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
     module.in_place = in_place
     _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
     _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
 
 
 def _patch_layer_norm_module(module, eps=1e-6):
-    module.variance_epsilon = (
-        getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
-    )
+    module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
     module.hidden_size = module.normalized_shape
     _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
     _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
@@ -145,9 +127,7 @@ def apply_liger_kernel_to_llama(
 
     for decoder_layer in base_model.layers:
         if swiglu:
-            _bind_method_to_module(
-                decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-            )
+            _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
         if rms_norm:
             _patch_rms_norm_module(decoder_layer.input_layernorm)
             _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -184,17 +164,13 @@ def apply_liger_kernel_to_mllama(
     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
 
     from transformers.models.mllama import modeling_mllama
-    from transformers.models.mllama.modeling_mllama import (
-        MllamaForCausalLM,
-        MllamaForConditionalGeneration,
-        MllamaTextModel,
-        MllamaVisionModel,
-    )
+    from transformers.models.mllama.modeling_mllama import MllamaForCausalLM
+    from transformers.models.mllama.modeling_mllama import MllamaForConditionalGeneration
+    from transformers.models.mllama.modeling_mllama import MllamaTextModel
+    from transformers.models.mllama.modeling_mllama import MllamaVisionModel
 
     from liger_kernel.transformers.model.mllama import lce_forward as mllama_lce_forward
-    from liger_kernel.transformers.model.mllama import (
-        lce_forward_deprecated as mllama_lce_forward_deprecated,
-    )
+    from liger_kernel.transformers.model.mllama import lce_forward_deprecated as mllama_lce_forward_deprecated
 
     if rope:
         modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
@@ -241,9 +217,7 @@ def apply_liger_kernel_to_mllama(
         _patch_rms_norm_module(text_model.norm)
         for decoder_layer in text_model.layers:
             if swiglu:
-                _bind_method_to_module(
-                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-                )
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -317,9 +291,7 @@ def apply_liger_kernel_to_mistral(
 
     for decoder_layer in base_model.layers:
         if swiglu:
-            _bind_method_to_module(
-                decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-            )
+            _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
         if rms_norm:
             _patch_rms_norm_module(decoder_layer.input_layernorm)
             _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -391,9 +363,7 @@ def apply_liger_kernel_to_mixtral(
     for decoder_layer in base_model.layers:
         if swiglu:
             for expert in decoder_layer.block_sparse_moe.experts:
-                _bind_method_to_module(
-                    expert, "forward", LigerBlockSparseTop2MLP.forward
-                )
+                _bind_method_to_module(expert, "forward", LigerBlockSparseTop2MLP.forward)
         if rms_norm:
             _patch_rms_norm_module(decoder_layer.input_layernorm)
             _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -431,12 +401,8 @@ def apply_liger_kernel_to_gemma(
     from transformers.models.gemma.modeling_gemma import GemmaModel
 
     # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
-    LigerRMSNormForGemma = partial(
-        LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
-    )
-    _patch_rms_norm_module_for_gemma = partial(
-        _patch_rms_norm_module, casting_mode="gemma", offset=1.0
-    )
+    LigerRMSNormForGemma = partial(LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma")
+    _patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0)
 
     if rope:
         modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
@@ -471,9 +437,7 @@ def apply_liger_kernel_to_gemma(
 
     for decoder_layer in base_model.layers:
         if geglu:
-            _bind_method_to_module(
-                decoder_layer.mlp, "forward", LigerGEGLUMLP.forward
-            )
+            _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
         if rms_norm:
             _patch_rms_norm_module_for_gemma(decoder_layer.input_layernorm)
             _patch_rms_norm_module_for_gemma(decoder_layer.post_attention_layernorm)
@@ -510,9 +474,7 @@ def apply_liger_kernel_to_gemma2(
     from transformers.models.gemma2 import modeling_gemma2
     from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
 
-    LigerRMSNormForGemma2 = partial(
-        LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False
-    )
+    LigerRMSNormForGemma2 = partial(LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False)
     _patch_rms_norm_module_for_gemma2 = partial(
         _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
     )
@@ -551,20 +513,12 @@ def apply_liger_kernel_to_gemma2(
 
     for decoder_layer in base_model.layers:
         if geglu:
-            _bind_method_to_module(
-                decoder_layer.mlp, "forward", LigerGEGLUMLP.forward
-            )
+            _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
         if rms_norm:
             _patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm)
-            _patch_rms_norm_module_for_gemma2(
-                decoder_layer.post_attention_layernorm
-            )
-            _patch_rms_norm_module_for_gemma2(
-                decoder_layer.pre_feedforward_layernorm
-            )
-            _patch_rms_norm_module_for_gemma2(
-                decoder_layer.post_feedforward_layernorm
-            )
+            _patch_rms_norm_module_for_gemma2(decoder_layer.post_attention_layernorm)
+            _patch_rms_norm_module_for_gemma2(decoder_layer.pre_feedforward_layernorm)
+            _patch_rms_norm_module_for_gemma2(decoder_layer.post_feedforward_layernorm)
 
 
 def apply_liger_kernel_to_qwen2(
@@ -633,9 +587,7 @@ def apply_liger_kernel_to_qwen2(
 
     for decoder_layer in base_model.layers:
         if swiglu:
-            _bind_method_to_module(
-                decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-            )
+            _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
         if rms_norm:
             _patch_rms_norm_module(decoder_layer.input_layernorm)
             _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -674,14 +626,10 @@ def apply_liger_kernel_to_qwen2_vl(
     from transformers.models.qwen2_vl import modeling_qwen2_vl
     from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel
 
-    from liger_kernel.transformers.model.qwen2_vl import (
-        lce_forward as qwen2_vl_lce_forward,
-    )
+    from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward
 
     if rope:
-        modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = (
-            liger_multimodal_rotary_pos_emb
-        )
+        modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
     if rms_norm:
         # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
         modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
@@ -712,9 +660,7 @@ def apply_liger_kernel_to_qwen2_vl(
         _patch_rms_norm_module(base_model.norm)
         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(
-                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-                )
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -783,9 +729,7 @@ def apply_liger_kernel_to_phi3(
 
     for decoder_layer in base_model.layers:
         if swiglu:
-            _bind_method_to_module(
-                decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward
-            )
+            _bind_method_to_module(decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward)
         if rms_norm:
             _patch_rms_norm_module(decoder_layer.input_layernorm)
             _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -826,24 +770,16 @@ def _apply_liger_kernel(model_type: str, **kwargs) -> None:
         return
 
     if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys():
-        logger.info(
-            f"There are currently no Liger kernels supported for model type: {model_type}."
-        )
+        logger.info(f"There are currently no Liger kernels supported for model type: {model_type}.")
        return
 
     apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
     apply_fn_signature = inspect.signature(apply_fn)
 
     # Filter out the keyword arguments that are not supported by the apply function
-    applicable_kwargs = {
-        key: value
-        for key, value in kwargs.items()
-        if key in apply_fn_signature.parameters
-    }
+    applicable_kwargs = {key: value for key, value in kwargs.items() if key in apply_fn_signature.parameters}
 
-    logger.info(
-        f"Applying Liger kernels for model type: {model_type} with kwargs: {applicable_kwargs}"
-    )
+    logger.info(f"Applying Liger kernels for model type: {model_type} with kwargs: {applicable_kwargs}")
 
     # Assume this is invoked pre-model initialization, so we only need to patch transformers code
     apply_fn(**applicable_kwargs)
@@ -857,20 +793,14 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
         - model: the model instance to apply Liger kernels to
         - kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function.
     """
-    model_type = getattr(model, "config", None) and getattr(
-        model.config, "model_type", None
-    )
+    model_type = getattr(model, "config", None) and getattr(model.config, "model_type", None)
 
     if not model_type:
-        logger.info(
-            "Model type could not be determined from model config. No Liger kernels will be applied."
-        )
+        logger.info("Model type could not be determined from model config. No Liger kernels will be applied.")
         return
 
     if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys():
-        logger.info(
-            f"There are currently no Liger kernels supported for model type: {model_type}."
-        )
+        logger.info(f"There are currently no Liger kernels supported for model type: {model_type}.")
         return
 
     apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
@@ -878,11 +808,7 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
     apply_fn_signature = inspect.signature(apply_fn)
 
     # Filter out the keyword arguments that are not supported by the apply function
-    applicable_kwargs = {
-        key: value
-        for key, value in kwargs.items()
-        if key in apply_fn_signature.parameters
-    }
+    applicable_kwargs = {key: value for key, value in kwargs.items() if key in apply_fn_signature.parameters}
     logger.info(
         f"Applying Liger kernels to model instance with model type: {model_type} with kwargs: {applicable_kwargs}"
     )
liger_kernel/transformers/rms_norm.py CHANGED
@@ -19,9 +19,7 @@ class LigerRMSNorm(nn.Module):
             "ones",
             "zeros",
         ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
-        self.weight = nn.Parameter(
-            torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)
-        )
+        self.weight = nn.Parameter(torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size))
         self.variance_epsilon, self.offset, self.casting_mode, self.in_place = (
             eps,
             offset,
@@ -40,4 +38,6 @@ class LigerRMSNorm(nn.Module):
         )
 
     def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}"
+        return (
+            f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}"
+        )
liger_kernel/transformers/rope.py CHANGED
@@ -8,8 +8,8 @@ def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     Args:
         q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim).
         k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim).
-        cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim).
-        sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim).
+        cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim) or (bsz, seq_len, head_dim).
+        sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim) or (bsz, seq_len, head_dim).
         position_ids (torch.Tensor, optional): The position ids tensor. Defaults to None.
         unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1.
 
liger_kernel/transformers/swiglu.py CHANGED
@@ -16,10 +16,7 @@ class LigerSwiGLUMLP(nn.Module):
             raise ValueError(f"Activation function {config.hidden_act} not supported.")
 
     def forward(self, x):
-
-        return self.down_proj(
-            LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x))
-        )
+        return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
 
 
 class LigerBlockSparseTop2MLP(nn.Module):
@@ -36,7 +33,6 @@ class LigerBlockSparseTop2MLP(nn.Module):
             raise ValueError(f"Activation function {config.hidden_act} not supported.")
 
     def forward(self, x):
-
         return self.w2(LigerSiLUMulFunction.apply(self.w1(x), self.w3(x)))
 
 
@@ -51,9 +47,7 @@ class LigerPhi3SwiGLUMLP(nn.Module):
         self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
-        self.gate_up_proj = nn.Linear(
-            self.hidden_size, 2 * self.intermediate_size, bias=False
-        )
+        self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         if config.hidden_act not in ["silu", "swish"]:
             raise ValueError(f"Activation function {config.hidden_act} not supported.")
liger_kernel/transformers/trainer/__init__.py CHANGED
@@ -1,6 +1,4 @@
 try:
-    from liger_kernel.transformers.trainer.orpo_trainer import (
-        LigerORPOTrainer,
-    )
+    from liger_kernel.transformers.trainer.orpo_trainer import LigerORPOTrainer  # noqa: F401
 except ImportError:
     raise ImportError("Please `pip install trl` to use LigerORPOTrainer")
liger_kernel/transformers/trainer/orpo_trainer.py CHANGED
@@ -1,7 +1,14 @@
-from typing import Any, Callable, Dict, List, Literal, Tuple, Union
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Literal
+from typing import Tuple
+from typing import Union
 
 import torch
 import torch.nn as nn
+
 from torch.distributed.fsdp import FullyShardedDataParallel
 from trl.trainer import ORPOTrainer
 
@@ -62,9 +69,7 @@ class _FSDPForwardRedirection:
 class LigerORPOTrainer(ORPOTrainer):
     def concatenated_forward(
         self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
-    ) -> Tuple[
-        torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor
-    ]:
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
         """
         Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
         We do this to avoid doing two forward passes, because it's faster for FSDP.
@@ -79,9 +84,7 @@ class LigerORPOTrainer(ORPOTrainer):
 
         model_kwargs = (
             {
-                "decoder_input_ids": self._shift_right(
-                    concatenated_batch["concatenated_labels"]
-                ),
+                "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
             }
             if self.is_encoder_decoder
             else {}
@@ -109,14 +112,10 @@ class LigerORPOTrainer(ORPOTrainer):
             **model_kwargs,
         )
 
-        orpo_loss_fn = LigerFusedLinearORPOLoss(
-            ignore_index=self.label_pad_token_id, beta=self.beta
-        )
+        orpo_loss_fn = LigerFusedLinearORPOLoss(ignore_index=self.label_pad_token_id, beta=self.beta)
 
         def orpo_partial(lm_head, last_hidden_state, concatenated_labels):
-            return orpo_loss_fn(
-                lm_head.weight, last_hidden_state, concatenated_labels, lm_head.bias
-            )
+            return orpo_loss_fn(lm_head.weight, last_hidden_state, concatenated_labels, lm_head.bias)
 
         orpo_loss, aux_outputs = _FSDPForwardRedirection()(
             model,
@@ -149,9 +148,7 @@ class LigerORPOTrainer(ORPOTrainer):
         ) = aux_outputs[:5]
 
         # return loss, metrics
-        chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = aux_outputs[
-            5:
-        ]
+        chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = aux_outputs[5:]
 
         reward_accuracies = (chosen_rewards > rejected_rewards).float()
 
liger_kernel/triton/__init__.py CHANGED
@@ -37,6 +37,4 @@ def apply_liger_triton_cache_manager():
     Experimental feature to get around transient FileNotFoundError in triton compilation.
     For more details please see https://github.com/triton-lang/triton/pull/4295
     """
-    os.environ["TRITON_CACHE_MANAGER"] = (
-        "liger_kernel.triton.monkey_patch:LigerTritonFileCacheManager"
-    )
+    os.environ["TRITON_CACHE_MANAGER"] = "liger_kernel.triton.monkey_patch:LigerTritonFileCacheManager"
liger_kernel_nightly-0.5.2.dev20241223042135.dist-info/RECORD ADDED
@@ -0,0 +1,66 @@
+liger_kernel/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+liger_kernel/env_report.py,sha256=uhdEC8OydxoZlb7B6YYcAaBF3crGFdIck-4cxaW4NJY,1728
+liger_kernel/utils.py,sha256=HJa-xVKOohDn6pLVIx-Fv0V9h0QAL3qZGQNRICI-OpI,249
+liger_kernel/chunked_loss/README.md,sha256=K6rucm6nqHpWCmxUOhBYcE3apwQxAy0TfRUippR7Icw,2243
+liger_kernel/chunked_loss/__init__.py,sha256=R2wCcz4Y0kTAve926DH3k182XKezpXeACMHj05g9Mm8,346
+liger_kernel/chunked_loss/cpo_loss.py,sha256=H-BU2QC5GzNQ4NnTM6TLgwvo-Eoh5YAE-Q_j1dX_w0g,3517
+liger_kernel/chunked_loss/dpo_loss.py,sha256=VYZMOafdvE8xlhvTtwjrz81tIzxR1mHF4lXdsADnIQg,4373
+liger_kernel/chunked_loss/functional.py,sha256=9Gr-YXIuEzEJkBUhDx3G2fuQayckLor7cC7svhmPML4,549
+liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=M-QWvGPnWefYDn6Hr9bPn7diMNP5qrUaeWTb_zdMO4E,10265
+liger_kernel/chunked_loss/fused_linear_preference.py,sha256=25sTgvphLKAR0jyJcrsJPKK1abFpTKrajSyAx8nJ3bc,16134
+liger_kernel/chunked_loss/orpo_loss.py,sha256=jbZxx-EjPK71A6CSyNzTOAIEQgAUjfvwSViw6R_pPXQ,3510
+liger_kernel/chunked_loss/simpo_loss.py,sha256=ZvDIjT9EQrbwzH2LNZMhv84SPsOHGi_Ywk95vgA0b_o,3736
+liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+liger_kernel/ops/cross_entropy.py,sha256=2OPIkSXeQAIfSCODYK45Jf8xrz7HoGqFHr1MHS_pijE,15895
+liger_kernel/ops/fused_linear_cross_entropy.py,sha256=ObNZjgYlCvigbgKl-FAjHAvk90wiwJ-4Wrf8JUHmlLQ,9346
+liger_kernel/ops/fused_linear_jsd.py,sha256=eKqaADj7LgWfoYqyH03tjrmhNTfJOF1Dhx_bWzBTnTU,9600
+liger_kernel/ops/geglu.py,sha256=axGvCIvlBzuluoAIrWTsp2iZM4BFKNInkPov8YVvH9E,4126
+liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2woggk,10837
+liger_kernel/ops/jsd.py,sha256=WwGY9ozuH3PMg3udRI6H96UqAEzIozJoO2HtHg7010M,6107
+liger_kernel/ops/kl_div.py,sha256=MnfuYqqQESON1X2Swy064x1urKtMFdgeSWd60VttBXI,8420
+liger_kernel/ops/layer_norm.py,sha256=quvt2zcwcJCDxrgm-iWoHzDYOoeZdMC76nZ_ckw6-p8,7640
+liger_kernel/ops/qwen2vl_mrope.py,sha256=3GExhYpLgB4VUtyZyjRk8XjEur3W4EWF6HQ67ML5vBU,8481
+liger_kernel/ops/rms_norm.py,sha256=PWLJcdIKU5e-8BuYFHd9Cqlq6wmr6fUXKi9zQD4LetU,11727
+liger_kernel/ops/rope.py,sha256=ofmBOkUpZZO-Q8Z5B_LOFYYLD-YT-8WnJ4vGOrDYouI,8943
+liger_kernel/ops/swiglu.py,sha256=KmgMjaJQnbLLgZn2nEpbwHU_xpnYRweCyrLQSVvM1vA,3015
+liger_kernel/ops/utils.py,sha256=vMWxfcw02xUvjpEXQQ3Rrj68ddZ8Of3hiOmEFq1zSKg,3852
+liger_kernel/ops/experimental/embedding.py,sha256=tolj3tItkzpSb30zWqDN2_yX4ectflaQ8HMyKyFIQc8,4172
+liger_kernel/ops/experimental/mm_int8int2.py,sha256=TrS9lpwekrik_w5qE7AhMJD1bcq-OidjtbsW80oZ6IM,13314
+liger_kernel/transformers/__init__.py,sha256=QPmYkL6hosBPpPqCUGqvIvAtD9XzLgvZqZxUyYMZeVk,2008
+liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawXOGgoCZjgEg,1545
+liger_kernel/transformers/cross_entropy.py,sha256=s5-ZM1NBMDjG-KKJKBtIkmArj1jCUjDnpL-2QKhKYho,1734
+liger_kernel/transformers/functional.py,sha256=hxReSBDEUZkOnZgURD8sf6ETYvf9yqCOOMU2k9Ywh90,4435
+liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=K4tfpoNPUJpWv7rCHEcs5xhJLg5td8GcpJrAryF5NMk,1451
+liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
+liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
+liger_kernel/transformers/group_norm.py,sha256=URmjkQFsrbMffzcJiGpX7ckxWlpL95AiJS-80hwAWPk,2173
+liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCcScY,2979
+liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
+liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
+liger_kernel/transformers/monkey_patch.py,sha256=6eXmtERKr4YUppRAaH7a_ml3AOz0ao68E8QnOyXtIkY,37794
+liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
+liger_kernel/transformers/rms_norm.py,sha256=GqCEJuGt0YdqqlMcToE0Wp4A8YFquDa4UUSyH2uFW2A,1191
+liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
+liger_kernel/transformers/swiglu.py,sha256=i9WTqcNRqReU4XJs391IPbl-I5X0wG4T72D4pqGFfJg,2422
+liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx-uy2f2cFfveZpqbUdhw,123
+liger_kernel/transformers/experimental/embedding.py,sha256=2P0QYdlFyFrG5OqTzTa1wcRgDSyjBMv5i1a7BrDPDQw,881
+liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+liger_kernel/transformers/model/gemma.py,sha256=ky89b3aWPaeTGRMC-745KgixtQIRXzNAiCORAMLn9yo,9654
+liger_kernel/transformers/model/gemma2.py,sha256=27NcoZjEqP9Lqb4Wf0EKqTbr2HTGiHPhrVyPCRGPz6s,10767
+liger_kernel/transformers/model/llama.py,sha256=3LJFXKFDKvEakaWPc_NicSFst4Y_hdSMrdl1UDK1EcA,10330
+liger_kernel/transformers/model/mistral.py,sha256=MVRksI5_j_8WJu8znOHKCdSI5jSu-S7cdFYzt9m_vIQ,5180
+liger_kernel/transformers/model/mixtral.py,sha256=jpZJkpl625Q-JHWarj2MqT5mRaSsiCtg0c9vVyvOdCY,11430
+liger_kernel/transformers/model/mllama.py,sha256=qWexBdskuN3gPJvPUwt4J0nU675tGD6W7wxgRZ9Bifg,11145
+liger_kernel/transformers/model/phi3.py,sha256=biRa8fph9qdnQmkD9I21t5XIjpIt1i6UKU4uk8Up8pU,10292
+liger_kernel/transformers/model/qwen2.py,sha256=14UuPjxB-tjqWn85Tn4fqBFvVhVsth5iPEt8kJSMiew,9581
+liger_kernel/transformers/model/qwen2_vl.py,sha256=rZg3nU3YgF6wkB1UJ0a9IACSIlVOSCyLltyqw951MQQ,8609
+liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7HHWHwku25A-GYL0WU,193
+liger_kernel/transformers/trainer/orpo_trainer.py,sha256=MId1S_MfA3pPVQA1rkiKxp-jZDNz8VmvZzXC-Kugol4,7662
+liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
+liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
+liger_kernel_nightly-0.5.2.dev20241223042135.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.2.dev20241223042135.dist-info/METADATA,sha256=diXsKJ9zCLk-w9SCZLWWx-xN0ZP8-W51KrgpISmaxn4,21055
+liger_kernel_nightly-0.5.2.dev20241223042135.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.2.dev20241223042135.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+liger_kernel_nightly-0.5.2.dev20241223042135.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.2.dev20241223042135.dist-info/RECORD,,
liger_kernel_nightly-0.5.2.dev20241223032015.dist-info/RECORD DELETED
@@ -1,66 +0,0 @@
-liger_kernel/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-liger_kernel/env_report.py,sha256=ok9PMXtO-8uLj_feCJI4h9hz2NtolZ2AG_OJTW5qmo4,1823
-liger_kernel/utils.py,sha256=HJa-xVKOohDn6pLVIx-Fv0V9h0QAL3qZGQNRICI-OpI,249
-liger_kernel/chunked_loss/README.md,sha256=K6rucm6nqHpWCmxUOhBYcE3apwQxAy0TfRUippR7Icw,2243
-liger_kernel/chunked_loss/__init__.py,sha256=R2wCcz4Y0kTAve926DH3k182XKezpXeACMHj05g9Mm8,346
-liger_kernel/chunked_loss/cpo_loss.py,sha256=3PdSp1gju1u0ffFGpUufbZPIva8aI3SW1TfqkJOpw1g,3554
-liger_kernel/chunked_loss/dpo_loss.py,sha256=jbTno1pKEc-HxAGFY3NSycBzdWyTacyRCzH3FhrMUMo,4383
-liger_kernel/chunked_loss/functional.py,sha256=9Gr-YXIuEzEJkBUhDx3G2fuQayckLor7cC7svhmPML4,549
-liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=2BH6DCPjsR2zS6zcwFPcIIZRhLF8SohjGdKsAJ_301o,10222
-liger_kernel/chunked_loss/fused_linear_preference.py,sha256=vvratrj8rba8NaGbO2ffbUfWMVEvDMxDCo6SI8nCtbo,16376
-liger_kernel/chunked_loss/orpo_loss.py,sha256=xHsKjlCWQVew7_hhpyUp3a1wd0tdpgx-zQAezNjk3Q4,3532
-liger_kernel/chunked_loss/simpo_loss.py,sha256=_5gXIkEAT0Kt_AufziQlYhBjzDJVSQVk7oSDHcrw1xw,3759
-liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-liger_kernel/ops/cross_entropy.py,sha256=3oPrw6KzIVc11gSyfdrLnj0WJB4qOYjE1tC8HJeFFpg,15888
-liger_kernel/ops/fused_linear_cross_entropy.py,sha256=Tnw4gyAYVVdnCOqhOuLEzbUQ3goOTnoAfk3pqSIM5ac,9301
-liger_kernel/ops/fused_linear_jsd.py,sha256=nOv4zwfxHqqepKEmMsQuz-B3H-gRjyo8uClpmqSGLYA,9693
-liger_kernel/ops/geglu.py,sha256=MQL4zyzneZqZYUGPvb1QjI_EYT9_pKfSDgR25WD9jrI,4127
-liger_kernel/ops/group_norm.py,sha256=VaRErVJGR4JqgXXvuIjNGTn3E2egjLtU1y3ymwIf4d8,10961
-liger_kernel/ops/jsd.py,sha256=Ap2b0_geCl6fqBXLI1IS6Yn6GlO-8LgPmnOW3y47dus,6151
-liger_kernel/ops/kl_div.py,sha256=vBz1ieu_sPcFbgG_wL0SwrbSQ6xVDK51_FNo-yf7CjY,8430
-liger_kernel/ops/layer_norm.py,sha256=_CZggw3GNEIUx5weDzadFit5I-Lzosoo8prgeJzcViY,7589
-liger_kernel/ops/qwen2vl_mrope.py,sha256=GvP4Cg-2ClYyiqbe7bB_OMvnlZooBmqP2-9V8RMPde4,8598
-liger_kernel/ops/rms_norm.py,sha256=bleuRC9IS_P3zEX07b0LZ_cpgeTH8l5sdvkelucpRgM,11792
-liger_kernel/ops/rope.py,sha256=jrzaA9-6Orn44y_IIam9_YNPQxOFK2FrIRNfFea4EtU,8513
-liger_kernel/ops/swiglu.py,sha256=Fwxtd76rhHKT9ShQAGca9RsnASplAVxtYKHmiT73_yA,2994
-liger_kernel/ops/utils.py,sha256=_VQvd1PX5JXm5xaiBrk2gANp3qr4kM7qYG3ypkBwkMs,3850
-liger_kernel/ops/experimental/embedding.py,sha256=LYR66dB-jhvhtUjeV4PnNro-n77J1mdlmpSLSxB3Y6U,4186
-liger_kernel/ops/experimental/mm_int8int2.py,sha256=JpGVZCgRC6T8XMUJ_QbZRS2XU1bh0urIZphs5DTc1mY,13358
-liger_kernel/transformers/__init__.py,sha256=gia-eBxr7TLxU0GdDf8AfCY4WgDlFLqIGSt7EoQGsBA,1336
-liger_kernel/transformers/auto_model.py,sha256=RMIwQHSiXoksXFTIqFZ4PLBgoqkxJJAT3q1Qh47bGN8,1552
-liger_kernel/transformers/cross_entropy.py,sha256=yEm_YQ7oa3_BzT3hdW6KrAslduhSqWcJQVNZZDcWCg4,1758
-liger_kernel/transformers/functional.py,sha256=sUBoU8Vb4pLpr9G6IdkRsToYgh-rCXL4OLYat7Tv_GU,4450
-liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=_i0PXSp5iZ9pKXdEeZ4lvHCENJYjV4y74yz3ZRG5XQg,1484
-liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
-liger_kernel/transformers/geglu.py,sha256=QcrME_8ooIn0xa59LaC0aoOdRrBIFd11Y0bAyF0NfCw,1130
-liger_kernel/transformers/group_norm.py,sha256=FJ9R7mS9G1wO-GRIQ6QKSmIhnZ6nQ6GIkE4NnX_hnn0,2241
-liger_kernel/transformers/jsd.py,sha256=sbr8DnKSYZJH9pv2rpmboNijYGpZKbhb2-WSGp5_v6g,3001
-liger_kernel/transformers/kl_div.py,sha256=qVhjBg6tjRyue5iZ3NFxo8uySY4JuIFJyv0IM_50F24,431
-liger_kernel/transformers/layer_norm.py,sha256=fd6o4kSHJWolQMWxh-l1qObfgL08ruNbUoBiANKX1ow,972
-liger_kernel/transformers/monkey_patch.py,sha256=Fk2v4GZQDJzfh3Cpc6BHNJbs_tungDyWmqS9nuG9Lc4,38406
-liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
-liger_kernel/transformers/rms_norm.py,sha256=AHstklNIO1PLHjjCBU-TPuUD-Fl_pycJUTLlJNojbV8,1189
-liger_kernel/transformers/rope.py,sha256=m-ah8vZBYW8tfplTXCiAPMHJWlB1tdp_JPXJeWE-Boo,943
-liger_kernel/transformers/swiglu.py,sha256=0-tVJ8xEYfhxnduc16PflXFj8sZPxdx9sHUn3hfwCI4,2468
-liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx-uy2f2cFfveZpqbUdhw,123
-liger_kernel/transformers/experimental/embedding.py,sha256=HpckiAMKM8-SRxKDcGTqortVxnjhwpZsfsp9lfjqfeM,895
-liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-liger_kernel/transformers/model/gemma.py,sha256=R4huxuR48gkLrdT8KqV7As2v9dZtEmcGVz6YG1ZmuJE,9692
-liger_kernel/transformers/model/gemma2.py,sha256=zxQsxCRqkoxCES3GJPVI7soUuF3J5HZDlvJgaBos1zM,10836
-liger_kernel/transformers/model/llama.py,sha256=RinsgC_eR-YNvZd2SHPQxZ4eyR3uViaTFCM3SvI5nks,10426
-liger_kernel/transformers/model/mistral.py,sha256=XpL1rlWg_llvW3z_Hf_d8WQs7uQaH4ds7EZ2SxjQHsU,5144
-liger_kernel/transformers/model/mixtral.py,sha256=JlNS6DA6SJqeHDk7j2LZymPQ3wngrTIo3wUGFBqHuJs,11504
-liger_kernel/transformers/model/mllama.py,sha256=mesNCgj0Ea1O-fqRD4LVxDJ1CR2abY_zAzK_bfVzkiU,11222
-liger_kernel/transformers/model/phi3.py,sha256=xUZPlaPKwknLjHc3uUW3EPodm1h0vD3G7Qnhh51v-Io,10332
-liger_kernel/transformers/model/qwen2.py,sha256=EyhSSzQOskGjSnCsKMZpd1s5IAIlHd5PBO3q0MoCs00,9619
-liger_kernel/transformers/model/qwen2_vl.py,sha256=bIQe2bWiY--G84FhCD29Gdi64_qHP6vbcGsK6vKysQE,8547
-liger_kernel/transformers/trainer/__init__.py,sha256=c4OQVJmhNOloj0JYSEc0j_cQuBbzGWILfaowUR1hmRw,210
-liger_kernel/transformers/trainer/orpo_trainer.py,sha256=O2k2vdHl-O1S-U61aEmyUFu3QrEuNAipQa2oUBb3HAA,7679
-liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
-liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
-liger_kernel_nightly-0.5.2.dev20241223032015.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
-liger_kernel_nightly-0.5.2.dev20241223032015.dist-info/METADATA,sha256=glSPMysElXhTUr1u74GrG_xjFSIek9GtE9AlPR6GkLs,21055
-liger_kernel_nightly-0.5.2.dev20241223032015.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
-liger_kernel_nightly-0.5.2.dev20241223032015.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
-liger_kernel_nightly-0.5.2.dev20241223032015.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
-liger_kernel_nightly-0.5.2.dev20241223032015.dist-info/RECORD,,