liger-kernel-nightly 0.5.2.dev20241223032630__py3-none-any.whl → 0.5.2.dev20241228022953__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/cpo_loss.py +5 -12
- liger_kernel/chunked_loss/dpo_loss.py +1 -4
- liger_kernel/chunked_loss/fused_linear_distillation.py +37 -37
- liger_kernel/chunked_loss/fused_linear_preference.py +40 -64
- liger_kernel/chunked_loss/orpo_loss.py +2 -6
- liger_kernel/chunked_loss/simpo_loss.py +4 -8
- liger_kernel/env_report.py +4 -11
- liger_kernel/ops/cross_entropy.py +7 -10
- liger_kernel/ops/experimental/embedding.py +1 -3
- liger_kernel/ops/experimental/mm_int8int2.py +3 -9
- liger_kernel/ops/fused_linear_cross_entropy.py +12 -17
- liger_kernel/ops/fused_linear_jsd.py +11 -29
- liger_kernel/ops/geglu.py +6 -17
- liger_kernel/ops/group_norm.py +11 -28
- liger_kernel/ops/jsd.py +2 -6
- liger_kernel/ops/kl_div.py +4 -7
- liger_kernel/ops/layer_norm.py +3 -5
- liger_kernel/ops/qwen2vl_mrope.py +8 -25
- liger_kernel/ops/rms_norm.py +11 -29
- liger_kernel/ops/rope.py +8 -24
- liger_kernel/ops/swiglu.py +4 -8
- liger_kernel/ops/utils.py +2 -0
- liger_kernel/transformers/__init__.py +16 -24
- liger_kernel/transformers/auto_model.py +6 -13
- liger_kernel/transformers/cross_entropy.py +1 -3
- liger_kernel/transformers/experimental/embedding.py +1 -3
- liger_kernel/transformers/functional.py +2 -6
- liger_kernel/transformers/fused_linear_cross_entropy.py +2 -6
- liger_kernel/transformers/geglu.py +1 -4
- liger_kernel/transformers/group_norm.py +3 -9
- liger_kernel/transformers/jsd.py +1 -3
- liger_kernel/transformers/kl_div.py +1 -3
- liger_kernel/transformers/layer_norm.py +3 -9
- liger_kernel/transformers/model/gemma.py +18 -40
- liger_kernel/transformers/model/gemma2.py +19 -41
- liger_kernel/transformers/model/llama.py +22 -48
- liger_kernel/transformers/model/mistral.py +14 -26
- liger_kernel/transformers/model/mixtral.py +23 -53
- liger_kernel/transformers/model/mllama.py +16 -36
- liger_kernel/transformers/model/phi3.py +18 -40
- liger_kernel/transformers/model/qwen2.py +18 -40
- liger_kernel/transformers/model/qwen2_vl.py +16 -30
- liger_kernel/transformers/monkey_patch.py +43 -117
- liger_kernel/transformers/rms_norm.py +4 -4
- liger_kernel/transformers/swiglu.py +2 -8
- liger_kernel/transformers/trainer/__init__.py +1 -3
- liger_kernel/transformers/trainer/orpo_trainer.py +13 -16
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -3
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/METADATA +1 -1
- liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/RECORD +66 -0
- liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/RECORD +0 -66
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/monkey_patch.py
CHANGED
@@ -1,9 +1,11 @@
 import inspect
 import logging
+
 from functools import partial
 from typing import Callable
 
 import transformers
+
 from packaging import version
 from transformers import PreTrainedModel
 
@@ -12,38 +14,24 @@ from liger_kernel.transformers.functional import liger_cross_entropy
 from liger_kernel.transformers.geglu import LigerGEGLUMLP
 from liger_kernel.transformers.layer_norm import LigerLayerNorm
 from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
-from liger_kernel.transformers.model.gemma import (
-    lce_forward_deprecated as gemma_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
 from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
-from liger_kernel.transformers.model.gemma2 import (
-    lce_forward_deprecated as gemma2_lce_forward_deprected,
-)
+from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
 from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
-from liger_kernel.transformers.model.llama import (
-    lce_forward_deprecated as llama_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
 from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
 from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
-from liger_kernel.transformers.model.mixtral import (
-    lce_forward_deprecated as mixtral_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated
 from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
-from liger_kernel.transformers.model.phi3 import (
-    lce_forward_deprecated as phi3_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.phi3 import lce_forward_deprecated as phi3_lce_forward_deprecated
 from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
-from liger_kernel.transformers.model.qwen2 import (
-    lce_forward_deprecated as qwen2_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.qwen2 import lce_forward_deprecated as qwen2_lce_forward_deprecated
 from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
-from liger_kernel.transformers.swiglu import (
-    LigerBlockSparseTop2MLP,
-    LigerPhi3SwiGLUMLP,
-    LigerSwiGLUMLP,
-)
+from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
+from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
+from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
 
 transformer_version = version.parse(transformers.__version__)
 
@@ -57,23 +45,17 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable):
     module.__dict__[method_name] = new_method.__get__(module, module.__class__)
 
 
-def _patch_rms_norm_module(
-    module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True
-):
+def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True):
     module.offset = offset
     module.casting_mode = casting_mode
-    module.variance_epsilon = (
-        getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
-    )
+    module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
     module.in_place = in_place
     _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
     _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
 
 
 def _patch_layer_norm_module(module, eps=1e-6):
-    module.variance_epsilon = (
-        getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
-    )
+    module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
     module.hidden_size = module.normalized_shape
    _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
     _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
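The `_bind_method_to_module` helper in the hunk above relies on the descriptor protocol: a plain function bound with `__get__` becomes a bound method, and placing it in the instance `__dict__` shadows the class method for that one instance only. A minimal self-contained sketch of the same technique (the `Greeter` and `loud_greet` names are hypothetical, for illustration only):

```python
# Sketch of instance-level method patching via the descriptor protocol.
class Greeter:
    def greet(self):
        return "hello"


def loud_greet(self):
    return "HELLO"


a, b = Greeter(), Greeter()
# fn.__get__(obj, cls) returns a bound method; storing it in obj.__dict__
# shadows the class attribute for this one instance only.
a.__dict__["greet"] = loud_greet.__get__(a, Greeter)

assert a.greet() == "HELLO"  # patched instance
assert b.greet() == "hello"  # every other instance is untouched
```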
@@ -145,9 +127,7 @@ def apply_liger_kernel_to_llama(
 
     for decoder_layer in base_model.layers:
         if swiglu:
-            _bind_method_to_module(
-                decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-            )
+            _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
         if rms_norm:
             _patch_rms_norm_module(decoder_layer.input_layernorm)
             _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -184,17 +164,13 @@ def apply_liger_kernel_to_mllama(
     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
 
     from transformers.models.mllama import modeling_mllama
-    from transformers.models.mllama.modeling_mllama import (
-        MllamaForCausalLM,
-        MllamaForConditionalGeneration,
-        MllamaTextModel,
-        MllamaVisionModel,
-    )
+    from transformers.models.mllama.modeling_mllama import MllamaForCausalLM
+    from transformers.models.mllama.modeling_mllama import MllamaForConditionalGeneration
+    from transformers.models.mllama.modeling_mllama import MllamaTextModel
+    from transformers.models.mllama.modeling_mllama import MllamaVisionModel
 
     from liger_kernel.transformers.model.mllama import lce_forward as mllama_lce_forward
-    from liger_kernel.transformers.model.mllama import (
-        lce_forward_deprecated as mllama_lce_forward_deprecated,
-    )
+    from liger_kernel.transformers.model.mllama import lce_forward_deprecated as mllama_lce_forward_deprecated
 
     if rope:
         modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
@@ -241,9 +217,7 @@ def apply_liger_kernel_to_mllama(
         _patch_rms_norm_module(text_model.norm)
         for decoder_layer in text_model.layers:
             if swiglu:
-                _bind_method_to_module(
-                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-                )
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -317,9 +291,7 @@ def apply_liger_kernel_to_mistral(
 
     for decoder_layer in base_model.layers:
         if swiglu:
-            _bind_method_to_module(
-                decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-            )
+            _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
         if rms_norm:
             _patch_rms_norm_module(decoder_layer.input_layernorm)
             _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -391,9 +363,7 @@ def apply_liger_kernel_to_mixtral(
     for decoder_layer in base_model.layers:
         if swiglu:
             for expert in decoder_layer.block_sparse_moe.experts:
-                _bind_method_to_module(
-                    expert, "forward", LigerBlockSparseTop2MLP.forward
-                )
+                _bind_method_to_module(expert, "forward", LigerBlockSparseTop2MLP.forward)
         if rms_norm:
             _patch_rms_norm_module(decoder_layer.input_layernorm)
             _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -431,12 +401,8 @@ def apply_liger_kernel_to_gemma(
     from transformers.models.gemma.modeling_gemma import GemmaModel
 
     # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
-    LigerRMSNormForGemma = partial(
-        LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
-    )
-    _patch_rms_norm_module_for_gemma = partial(
-        _patch_rms_norm_module, casting_mode="gemma", offset=1.0
-    )
+    LigerRMSNormForGemma = partial(LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma")
+    _patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0)
 
     if rope:
         modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
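`LigerRMSNormForGemma` above uses `functools.partial` to pre-bind Gemma-specific keyword arguments to a class constructor, producing a drop-in factory. A minimal sketch of the same pattern (the `RMSNorm` class here is a simplified stand-in, not the real `LigerRMSNorm`):

```python
from functools import partial


# Simplified stand-in for LigerRMSNorm; only the constructor matters here.
class RMSNorm:
    def __init__(self, hidden_size, offset=0.0, casting_mode="llama", init_fn="ones"):
        self.hidden_size = hidden_size
        self.offset = offset
        self.casting_mode = casting_mode
        self.init_fn = init_fn


# partial() freezes the model-specific keywords; positional args still flow through.
RMSNormForGemma = partial(RMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma")

norm = RMSNormForGemma(2048)
assert (norm.offset, norm.casting_mode, norm.init_fn) == (1.0, "gemma", "zeros")
```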
@@ -471,9 +437,7 @@ def apply_liger_kernel_to_gemma(
 
     for decoder_layer in base_model.layers:
         if geglu:
-            _bind_method_to_module(
-                decoder_layer.mlp, "forward", LigerGEGLUMLP.forward
-            )
+            _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
         if rms_norm:
             _patch_rms_norm_module_for_gemma(decoder_layer.input_layernorm)
             _patch_rms_norm_module_for_gemma(decoder_layer.post_attention_layernorm)
@@ -510,9 +474,7 @@ def apply_liger_kernel_to_gemma2(
     from transformers.models.gemma2 import modeling_gemma2
     from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
 
-    LigerRMSNormForGemma2 = partial(
-        LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False
-    )
+    LigerRMSNormForGemma2 = partial(LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False)
     _patch_rms_norm_module_for_gemma2 = partial(
         _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
     )
@@ -551,20 +513,12 @@ def apply_liger_kernel_to_gemma2(
 
     for decoder_layer in base_model.layers:
         if geglu:
-            _bind_method_to_module(
-                decoder_layer.mlp, "forward", LigerGEGLUMLP.forward
-            )
+            _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
         if rms_norm:
             _patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm)
-            _patch_rms_norm_module_for_gemma2(
-                decoder_layer.post_attention_layernorm
-            )
-            _patch_rms_norm_module_for_gemma2(
-                decoder_layer.pre_feedforward_layernorm
-            )
-            _patch_rms_norm_module_for_gemma2(
-                decoder_layer.post_feedforward_layernorm
-            )
+            _patch_rms_norm_module_for_gemma2(decoder_layer.post_attention_layernorm)
+            _patch_rms_norm_module_for_gemma2(decoder_layer.pre_feedforward_layernorm)
+            _patch_rms_norm_module_for_gemma2(decoder_layer.post_feedforward_layernorm)
 
 
 def apply_liger_kernel_to_qwen2(
@@ -633,9 +587,7 @@ def apply_liger_kernel_to_qwen2(
 
     for decoder_layer in base_model.layers:
         if swiglu:
-            _bind_method_to_module(
-                decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-            )
+            _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
         if rms_norm:
             _patch_rms_norm_module(decoder_layer.input_layernorm)
             _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -674,14 +626,10 @@ def apply_liger_kernel_to_qwen2_vl(
     from transformers.models.qwen2_vl import modeling_qwen2_vl
     from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel
 
-    from liger_kernel.transformers.model.qwen2_vl import (
-        lce_forward as qwen2_vl_lce_forward,
-    )
+    from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward
 
     if rope:
-        modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = (
-            liger_multimodal_rotary_pos_emb
-        )
+        modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
     if rms_norm:
         # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
         modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
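The `modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm` assignment above works because Python resolves module attributes at call time: rebinding the name on the module object redirects every later lookup, including class instantiation inside `transformers`. A minimal sketch with hypothetical names (`mod` plays the role of `modeling_qwen2_vl`):

```python
import types

mod = types.ModuleType("mod")


class SlowNorm:
    pass


class FastNorm:
    pass


mod.Norm = SlowNorm


def build_layer():
    # The name is looked up on the module at call time, not at import time.
    return mod.Norm()


mod.Norm = FastNorm  # the monkey patch
assert isinstance(build_layer(), FastNorm)
```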
@@ -712,9 +660,7 @@ def apply_liger_kernel_to_qwen2_vl(
         _patch_rms_norm_module(base_model.norm)
         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(
-                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-                )
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -783,9 +729,7 @@ def apply_liger_kernel_to_phi3(
 
     for decoder_layer in base_model.layers:
         if swiglu:
-            _bind_method_to_module(
-                decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward
-            )
+            _bind_method_to_module(decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward)
         if rms_norm:
             _patch_rms_norm_module(decoder_layer.input_layernorm)
             _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -826,24 +770,16 @@ def _apply_liger_kernel(model_type: str, **kwargs) -> None:
         return
 
     if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys():
-        logger.info(
-            f"There are currently no Liger kernels supported for model type: {model_type}."
-        )
+        logger.info(f"There are currently no Liger kernels supported for model type: {model_type}.")
         return
 
     apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
     apply_fn_signature = inspect.signature(apply_fn)
 
     # Filter out the keyword arguments that are not supported by the apply function
-    applicable_kwargs = {
-        key: value
-        for key, value in kwargs.items()
-        if key in apply_fn_signature.parameters
-    }
+    applicable_kwargs = {key: value for key, value in kwargs.items() if key in apply_fn_signature.parameters}
 
-    logger.info(
-        f"Applying Liger kernels for model type: {model_type} with kwargs: {applicable_kwargs}"
-    )
+    logger.info(f"Applying Liger kernels for model type: {model_type} with kwargs: {applicable_kwargs}")
 
     # Assume this is invoked pre-model initialization, so we only need to patch transformers code
     apply_fn(**applicable_kwargs)
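`_apply_liger_kernel` above filters `**kwargs` down to the parameters the target function actually declares, so callers can pass a superset of options without triggering `TypeError`. A minimal sketch of the idiom (`patch_fn` and its flags are hypothetical):

```python
import inspect


# Hypothetical patch function with a limited set of supported flags.
def patch_fn(rope=True, rms_norm=True, swiglu=True):
    return {"rope": rope, "rms_norm": rms_norm, "swiglu": swiglu}


kwargs = {"rope": False, "geglu": True}  # "geglu" is not a parameter of patch_fn

# signature().parameters lists the declared parameter names, so unsupported
# keys can be dropped instead of blowing up the call.
sig = inspect.signature(patch_fn)
applicable = {k: v for k, v in kwargs.items() if k in sig.parameters}

assert applicable == {"rope": False}
assert patch_fn(**applicable)["rope"] is False
```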
@@ -857,20 +793,14 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
         - model: the model instance to apply Liger kernels to
         - kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function.
     """
-    model_type = getattr(model, "config", None) and getattr(
-        model.config, "model_type", None
-    )
+    model_type = getattr(model, "config", None) and getattr(model.config, "model_type", None)
 
     if not model_type:
-        logger.info(
-            "Model type could not be determined from model config. No Liger kernels will be applied."
-        )
+        logger.info("Model type could not be determined from model config. No Liger kernels will be applied.")
         return
 
     if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys():
-        logger.info(
-            f"There are currently no Liger kernels supported for model type: {model_type}."
-        )
+        logger.info(f"There are currently no Liger kernels supported for model type: {model_type}.")
         return
 
     apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
@@ -878,11 +808,7 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
     apply_fn_signature = inspect.signature(apply_fn)
 
     # Filter out the keyword arguments that are not supported by the apply function
-    applicable_kwargs = {
-        key: value
-        for key, value in kwargs.items()
-        if key in apply_fn_signature.parameters
-    }
+    applicable_kwargs = {key: value for key, value in kwargs.items() if key in apply_fn_signature.parameters}
     logger.info(
         f"Applying Liger kernels to model instance with model type: {model_type} with kwargs: {applicable_kwargs}"
     )
liger_kernel/transformers/rms_norm.py
CHANGED
@@ -19,9 +19,7 @@ class LigerRMSNorm(nn.Module):
             "ones",
             "zeros",
         ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
-        self.weight = nn.Parameter(
-            torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)
-        )
+        self.weight = nn.Parameter(torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size))
         self.variance_epsilon, self.offset, self.casting_mode, self.in_place = (
            eps,
             offset,
@@ -40,4 +38,6 @@ class LigerRMSNorm(nn.Module):
         )
 
     def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}"
+        return (
+            f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}"
+        )
liger_kernel/transformers/swiglu.py
CHANGED
@@ -16,10 +16,7 @@ class LigerSwiGLUMLP(nn.Module):
             raise ValueError(f"Activation function {config.hidden_act} not supported.")
 
     def forward(self, x):
-
-        return self.down_proj(
-            LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x))
-        )
+        return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
 
 
 class LigerBlockSparseTop2MLP(nn.Module):
@@ -36,7 +33,6 @@ class LigerBlockSparseTop2MLP(nn.Module):
         raise ValueError(f"Activation function {config.hidden_act} not supported.")
 
     def forward(self, x):
-
         return self.w2(LigerSiLUMulFunction.apply(self.w1(x), self.w3(x)))
 
 
@@ -51,9 +47,7 @@ class LigerPhi3SwiGLUMLP(nn.Module):
         self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
-        self.gate_up_proj = nn.Linear(
-            self.hidden_size, 2 * self.intermediate_size, bias=False
-        )
+        self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         if config.hidden_act not in ["silu", "swish"]:
             raise ValueError(f"Activation function {config.hidden_act} not supported.")
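For reference, the `forward` methods above all follow the SwiGLU shape, `down_proj(silu(gate) * up)`, with `LigerSiLUMulFunction.apply` fusing the silu-and-multiply step in a Triton kernel. A plain-PyTorch sketch of the same computation (unfused, for clarity; the layer sizes are arbitrary):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLUMLP(nn.Module):
    """Unfused reference for the SwiGLU pattern used by LigerSwiGLUMLP."""

    def __init__(self, hidden_size=16, intermediate_size=32):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):
        # silu(gate) * up is the step LigerSiLUMulFunction.apply fuses.
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


out = SwiGLUMLP()(torch.randn(2, 16))
assert out.shape == (2, 16)
```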
liger_kernel/transformers/trainer/__init__.py
CHANGED
@@ -1,6 +1,4 @@
 try:
-    from liger_kernel.transformers.trainer.orpo_trainer import (
-        LigerORPOTrainer,
-    )
+    from liger_kernel.transformers.trainer.orpo_trainer import LigerORPOTrainer  # noqa: F401
 except ImportError:
     raise ImportError("Please `pip install trl` to use LigerORPOTrainer")
liger_kernel/transformers/trainer/orpo_trainer.py
CHANGED
@@ -1,7 +1,14 @@
-from typing import Any, Callable, Dict, List, Literal, Tuple, Union
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Literal
+from typing import Tuple
+from typing import Union
 
 import torch
 import torch.nn as nn
+
 from torch.distributed.fsdp import FullyShardedDataParallel
 from trl.trainer import ORPOTrainer
 
@@ -62,9 +69,7 @@ class _FSDPForwardRedirection:
 class LigerORPOTrainer(ORPOTrainer):
     def concatenated_forward(
         self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
-    ) -> Tuple[
-        torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor
-    ]:
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
         """
         Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
         We do this to avoid doing two forward passes, because it's faster for FSDP.
@@ -79,9 +84,7 @@ class LigerORPOTrainer(ORPOTrainer):
 
         model_kwargs = (
             {
-                "decoder_input_ids": self._shift_right(
-                    concatenated_batch["concatenated_labels"]
-                ),
+                "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
             }
             if self.is_encoder_decoder
             else {}
@@ -109,14 +112,10 @@ class LigerORPOTrainer(ORPOTrainer):
             **model_kwargs,
         )
 
-        orpo_loss_fn = LigerFusedLinearORPOLoss(
-            ignore_index=self.label_pad_token_id, beta=self.beta
-        )
+        orpo_loss_fn = LigerFusedLinearORPOLoss(ignore_index=self.label_pad_token_id, beta=self.beta)
 
         def orpo_partial(lm_head, last_hidden_state, concatenated_labels):
-            return orpo_loss_fn(
-                lm_head.weight, last_hidden_state, concatenated_labels, lm_head.bias
-            )
+            return orpo_loss_fn(lm_head.weight, last_hidden_state, concatenated_labels, lm_head.bias)
 
         orpo_loss, aux_outputs = _FSDPForwardRedirection()(
             model,
@@ -149,9 +148,7 @@ class LigerORPOTrainer(ORPOTrainer):
         ) = aux_outputs[:5]
 
         # return loss, metrics
-        chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = aux_outputs[
-            5:
-        ]
+        chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = aux_outputs[5:]
 
         reward_accuracies = (chosen_rewards > rejected_rewards).float()
 
liger_kernel/triton/__init__.py
CHANGED
@@ -37,6 +37,4 @@ def apply_liger_triton_cache_manager():
     Experimental feature to get around transient FileNotFoundError in triton compilation.
     For more details please see https://github.com/triton-lang/triton/pull/4295
     """
-    os.environ["TRITON_CACHE_MANAGER"] = (
-        "liger_kernel.triton.monkey_patch:LigerTritonFileCacheManager"
-    )
+    os.environ["TRITON_CACHE_MANAGER"] = "liger_kernel.triton.monkey_patch:LigerTritonFileCacheManager"
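For context, the `TRITON_CACHE_MANAGER` value above appears to follow a `module.path:ClassName` spec that Triton resolves to its compilation-cache manager class (see the linked PR), so the patch only affects kernels compiled after the variable is set. A usage sketch, assuming liger-kernel is installed:

```python
import os

# apply_liger_triton_cache_manager() is defined in liger_kernel/triton/__init__.py
# (the file shown above); call it before any Triton kernel is compiled.
from liger_kernel.triton import apply_liger_triton_cache_manager

apply_liger_triton_cache_manager()
assert (
    os.environ["TRITON_CACHE_MANAGER"]
    == "liger_kernel.triton.monkey_patch:LigerTritonFileCacheManager"
)
```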
liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/RECORD
ADDED
@@ -0,0 +1,66 @@
+liger_kernel/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+liger_kernel/env_report.py,sha256=uhdEC8OydxoZlb7B6YYcAaBF3crGFdIck-4cxaW4NJY,1728
+liger_kernel/utils.py,sha256=HJa-xVKOohDn6pLVIx-Fv0V9h0QAL3qZGQNRICI-OpI,249
+liger_kernel/chunked_loss/README.md,sha256=K6rucm6nqHpWCmxUOhBYcE3apwQxAy0TfRUippR7Icw,2243
+liger_kernel/chunked_loss/__init__.py,sha256=R2wCcz4Y0kTAve926DH3k182XKezpXeACMHj05g9Mm8,346
+liger_kernel/chunked_loss/cpo_loss.py,sha256=L4Nk38Xh5Yfhah3Vsc_sN_Q75FWt1LA-xNNXzsK8iPM,3516
+liger_kernel/chunked_loss/dpo_loss.py,sha256=VYZMOafdvE8xlhvTtwjrz81tIzxR1mHF4lXdsADnIQg,4373
+liger_kernel/chunked_loss/functional.py,sha256=9Gr-YXIuEzEJkBUhDx3G2fuQayckLor7cC7svhmPML4,549
+liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=M-QWvGPnWefYDn6Hr9bPn7diMNP5qrUaeWTb_zdMO4E,10265
+liger_kernel/chunked_loss/fused_linear_preference.py,sha256=25sTgvphLKAR0jyJcrsJPKK1abFpTKrajSyAx8nJ3bc,16134
+liger_kernel/chunked_loss/orpo_loss.py,sha256=jbZxx-EjPK71A6CSyNzTOAIEQgAUjfvwSViw6R_pPXQ,3510
+liger_kernel/chunked_loss/simpo_loss.py,sha256=ZvDIjT9EQrbwzH2LNZMhv84SPsOHGi_Ywk95vgA0b_o,3736
+liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+liger_kernel/ops/cross_entropy.py,sha256=2OPIkSXeQAIfSCODYK45Jf8xrz7HoGqFHr1MHS_pijE,15895
+liger_kernel/ops/fused_linear_cross_entropy.py,sha256=LR0zLL8JYMhk9e22jmBxU4lwEYic3YqMAG3837yaHmM,9418
+liger_kernel/ops/fused_linear_jsd.py,sha256=eKqaADj7LgWfoYqyH03tjrmhNTfJOF1Dhx_bWzBTnTU,9600
+liger_kernel/ops/geglu.py,sha256=axGvCIvlBzuluoAIrWTsp2iZM4BFKNInkPov8YVvH9E,4126
+liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2woggk,10837
+liger_kernel/ops/jsd.py,sha256=WwGY9ozuH3PMg3udRI6H96UqAEzIozJoO2HtHg7010M,6107
+liger_kernel/ops/kl_div.py,sha256=MnfuYqqQESON1X2Swy064x1urKtMFdgeSWd60VttBXI,8420
+liger_kernel/ops/layer_norm.py,sha256=quvt2zcwcJCDxrgm-iWoHzDYOoeZdMC76nZ_ckw6-p8,7640
+liger_kernel/ops/qwen2vl_mrope.py,sha256=3GExhYpLgB4VUtyZyjRk8XjEur3W4EWF6HQ67ML5vBU,8481
+liger_kernel/ops/rms_norm.py,sha256=PWLJcdIKU5e-8BuYFHd9Cqlq6wmr6fUXKi9zQD4LetU,11727
+liger_kernel/ops/rope.py,sha256=ofmBOkUpZZO-Q8Z5B_LOFYYLD-YT-8WnJ4vGOrDYouI,8943
+liger_kernel/ops/swiglu.py,sha256=KmgMjaJQnbLLgZn2nEpbwHU_xpnYRweCyrLQSVvM1vA,3015
+liger_kernel/ops/utils.py,sha256=vMWxfcw02xUvjpEXQQ3Rrj68ddZ8Of3hiOmEFq1zSKg,3852
+liger_kernel/ops/experimental/embedding.py,sha256=tolj3tItkzpSb30zWqDN2_yX4ectflaQ8HMyKyFIQc8,4172
+liger_kernel/ops/experimental/mm_int8int2.py,sha256=TrS9lpwekrik_w5qE7AhMJD1bcq-OidjtbsW80oZ6IM,13314
+liger_kernel/transformers/__init__.py,sha256=QPmYkL6hosBPpPqCUGqvIvAtD9XzLgvZqZxUyYMZeVk,2008
+liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawXOGgoCZjgEg,1545
+liger_kernel/transformers/cross_entropy.py,sha256=s5-ZM1NBMDjG-KKJKBtIkmArj1jCUjDnpL-2QKhKYho,1734
+liger_kernel/transformers/functional.py,sha256=hxReSBDEUZkOnZgURD8sf6ETYvf9yqCOOMU2k9Ywh90,4435
+liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=K4tfpoNPUJpWv7rCHEcs5xhJLg5td8GcpJrAryF5NMk,1451
+liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
+liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
+liger_kernel/transformers/group_norm.py,sha256=URmjkQFsrbMffzcJiGpX7ckxWlpL95AiJS-80hwAWPk,2173
+liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCcScY,2979
+liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
+liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
+liger_kernel/transformers/monkey_patch.py,sha256=6eXmtERKr4YUppRAaH7a_ml3AOz0ao68E8QnOyXtIkY,37794
+liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
+liger_kernel/transformers/rms_norm.py,sha256=GqCEJuGt0YdqqlMcToE0Wp4A8YFquDa4UUSyH2uFW2A,1191
+liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
+liger_kernel/transformers/swiglu.py,sha256=i9WTqcNRqReU4XJs391IPbl-I5X0wG4T72D4pqGFfJg,2422
+liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx-uy2f2cFfveZpqbUdhw,123
+liger_kernel/transformers/experimental/embedding.py,sha256=2P0QYdlFyFrG5OqTzTa1wcRgDSyjBMv5i1a7BrDPDQw,881
+liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+liger_kernel/transformers/model/gemma.py,sha256=ky89b3aWPaeTGRMC-745KgixtQIRXzNAiCORAMLn9yo,9654
+liger_kernel/transformers/model/gemma2.py,sha256=27NcoZjEqP9Lqb4Wf0EKqTbr2HTGiHPhrVyPCRGPz6s,10767
+liger_kernel/transformers/model/llama.py,sha256=3LJFXKFDKvEakaWPc_NicSFst4Y_hdSMrdl1UDK1EcA,10330
+liger_kernel/transformers/model/mistral.py,sha256=MVRksI5_j_8WJu8znOHKCdSI5jSu-S7cdFYzt9m_vIQ,5180
+liger_kernel/transformers/model/mixtral.py,sha256=jpZJkpl625Q-JHWarj2MqT5mRaSsiCtg0c9vVyvOdCY,11430
+liger_kernel/transformers/model/mllama.py,sha256=qWexBdskuN3gPJvPUwt4J0nU675tGD6W7wxgRZ9Bifg,11145
+liger_kernel/transformers/model/phi3.py,sha256=biRa8fph9qdnQmkD9I21t5XIjpIt1i6UKU4uk8Up8pU,10292
+liger_kernel/transformers/model/qwen2.py,sha256=14UuPjxB-tjqWn85Tn4fqBFvVhVsth5iPEt8kJSMiew,9581
+liger_kernel/transformers/model/qwen2_vl.py,sha256=rZg3nU3YgF6wkB1UJ0a9IACSIlVOSCyLltyqw951MQQ,8609
+liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7HHWHwku25A-GYL0WU,193
+liger_kernel/transformers/trainer/orpo_trainer.py,sha256=MId1S_MfA3pPVQA1rkiKxp-jZDNz8VmvZzXC-Kugol4,7662
+liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
+liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
+liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/METADATA,sha256=Z5fzI-xpYPtjwawEGwIw-LRJUIeY1VEdDUK9wgklR7w,21055
+liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/RECORD,,
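Each RECORD line above follows the wheel spec: `path,sha256=<urlsafe base64 of the SHA-256 digest, padding stripped>,<size in bytes>`; the trailing `RECORD,,` entry is blank because RECORD cannot hash itself. A sketch reproducing the hash of an empty file, which matches the `liger_kernel/__init__.py` entry above:

```python
import base64
import hashlib

data = b""  # liger_kernel/__init__.py is empty (size 0 in RECORD)
digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=")
line = f"liger_kernel/__init__.py,sha256={digest.decode()},{len(data)}"
assert line == "liger_kernel/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0"
```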
liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/RECORD
DELETED
@@ -1,66 +0,0 @@
-liger_kernel/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-liger_kernel/env_report.py,sha256=ok9PMXtO-8uLj_feCJI4h9hz2NtolZ2AG_OJTW5qmo4,1823
-liger_kernel/utils.py,sha256=HJa-xVKOohDn6pLVIx-Fv0V9h0QAL3qZGQNRICI-OpI,249
-liger_kernel/chunked_loss/README.md,sha256=K6rucm6nqHpWCmxUOhBYcE3apwQxAy0TfRUippR7Icw,2243
-liger_kernel/chunked_loss/__init__.py,sha256=R2wCcz4Y0kTAve926DH3k182XKezpXeACMHj05g9Mm8,346
-liger_kernel/chunked_loss/cpo_loss.py,sha256=3PdSp1gju1u0ffFGpUufbZPIva8aI3SW1TfqkJOpw1g,3554
-liger_kernel/chunked_loss/dpo_loss.py,sha256=jbTno1pKEc-HxAGFY3NSycBzdWyTacyRCzH3FhrMUMo,4383
-liger_kernel/chunked_loss/functional.py,sha256=9Gr-YXIuEzEJkBUhDx3G2fuQayckLor7cC7svhmPML4,549
-liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=2BH6DCPjsR2zS6zcwFPcIIZRhLF8SohjGdKsAJ_301o,10222
-liger_kernel/chunked_loss/fused_linear_preference.py,sha256=vvratrj8rba8NaGbO2ffbUfWMVEvDMxDCo6SI8nCtbo,16376
-liger_kernel/chunked_loss/orpo_loss.py,sha256=xHsKjlCWQVew7_hhpyUp3a1wd0tdpgx-zQAezNjk3Q4,3532
-liger_kernel/chunked_loss/simpo_loss.py,sha256=_5gXIkEAT0Kt_AufziQlYhBjzDJVSQVk7oSDHcrw1xw,3759
-liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-liger_kernel/ops/cross_entropy.py,sha256=3oPrw6KzIVc11gSyfdrLnj0WJB4qOYjE1tC8HJeFFpg,15888
-liger_kernel/ops/fused_linear_cross_entropy.py,sha256=Tnw4gyAYVVdnCOqhOuLEzbUQ3goOTnoAfk3pqSIM5ac,9301
-liger_kernel/ops/fused_linear_jsd.py,sha256=nOv4zwfxHqqepKEmMsQuz-B3H-gRjyo8uClpmqSGLYA,9693
-liger_kernel/ops/geglu.py,sha256=MQL4zyzneZqZYUGPvb1QjI_EYT9_pKfSDgR25WD9jrI,4127
-liger_kernel/ops/group_norm.py,sha256=VaRErVJGR4JqgXXvuIjNGTn3E2egjLtU1y3ymwIf4d8,10961
-liger_kernel/ops/jsd.py,sha256=Ap2b0_geCl6fqBXLI1IS6Yn6GlO-8LgPmnOW3y47dus,6151
-liger_kernel/ops/kl_div.py,sha256=vBz1ieu_sPcFbgG_wL0SwrbSQ6xVDK51_FNo-yf7CjY,8430
-liger_kernel/ops/layer_norm.py,sha256=_CZggw3GNEIUx5weDzadFit5I-Lzosoo8prgeJzcViY,7589
-liger_kernel/ops/qwen2vl_mrope.py,sha256=GvP4Cg-2ClYyiqbe7bB_OMvnlZooBmqP2-9V8RMPde4,8598
-liger_kernel/ops/rms_norm.py,sha256=bleuRC9IS_P3zEX07b0LZ_cpgeTH8l5sdvkelucpRgM,11792
-liger_kernel/ops/rope.py,sha256=KyjbI6ya6bDwmdBJKK1IamuTUMpAmfdsHFYRJ4d9cP8,9059
-liger_kernel/ops/swiglu.py,sha256=Fwxtd76rhHKT9ShQAGca9RsnASplAVxtYKHmiT73_yA,2994
-liger_kernel/ops/utils.py,sha256=_VQvd1PX5JXm5xaiBrk2gANp3qr4kM7qYG3ypkBwkMs,3850
-liger_kernel/ops/experimental/embedding.py,sha256=LYR66dB-jhvhtUjeV4PnNro-n77J1mdlmpSLSxB3Y6U,4186
-liger_kernel/ops/experimental/mm_int8int2.py,sha256=JpGVZCgRC6T8XMUJ_QbZRS2XU1bh0urIZphs5DTc1mY,13358
-liger_kernel/transformers/__init__.py,sha256=gia-eBxr7TLxU0GdDf8AfCY4WgDlFLqIGSt7EoQGsBA,1336
-liger_kernel/transformers/auto_model.py,sha256=RMIwQHSiXoksXFTIqFZ4PLBgoqkxJJAT3q1Qh47bGN8,1552
-liger_kernel/transformers/cross_entropy.py,sha256=yEm_YQ7oa3_BzT3hdW6KrAslduhSqWcJQVNZZDcWCg4,1758
-liger_kernel/transformers/functional.py,sha256=sUBoU8Vb4pLpr9G6IdkRsToYgh-rCXL4OLYat7Tv_GU,4450
-liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=_i0PXSp5iZ9pKXdEeZ4lvHCENJYjV4y74yz3ZRG5XQg,1484
-liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
-liger_kernel/transformers/geglu.py,sha256=QcrME_8ooIn0xa59LaC0aoOdRrBIFd11Y0bAyF0NfCw,1130
-liger_kernel/transformers/group_norm.py,sha256=FJ9R7mS9G1wO-GRIQ6QKSmIhnZ6nQ6GIkE4NnX_hnn0,2241
-liger_kernel/transformers/jsd.py,sha256=sbr8DnKSYZJH9pv2rpmboNijYGpZKbhb2-WSGp5_v6g,3001
-liger_kernel/transformers/kl_div.py,sha256=qVhjBg6tjRyue5iZ3NFxo8uySY4JuIFJyv0IM_50F24,431
-liger_kernel/transformers/layer_norm.py,sha256=fd6o4kSHJWolQMWxh-l1qObfgL08ruNbUoBiANKX1ow,972
-liger_kernel/transformers/monkey_patch.py,sha256=Fk2v4GZQDJzfh3Cpc6BHNJbs_tungDyWmqS9nuG9Lc4,38406
-liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
-liger_kernel/transformers/rms_norm.py,sha256=AHstklNIO1PLHjjCBU-TPuUD-Fl_pycJUTLlJNojbV8,1189
-liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
-liger_kernel/transformers/swiglu.py,sha256=0-tVJ8xEYfhxnduc16PflXFj8sZPxdx9sHUn3hfwCI4,2468
-liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx-uy2f2cFfveZpqbUdhw,123
-liger_kernel/transformers/experimental/embedding.py,sha256=HpckiAMKM8-SRxKDcGTqortVxnjhwpZsfsp9lfjqfeM,895
-liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-liger_kernel/transformers/model/gemma.py,sha256=R4huxuR48gkLrdT8KqV7As2v9dZtEmcGVz6YG1ZmuJE,9692
-liger_kernel/transformers/model/gemma2.py,sha256=zxQsxCRqkoxCES3GJPVI7soUuF3J5HZDlvJgaBos1zM,10836
-liger_kernel/transformers/model/llama.py,sha256=RinsgC_eR-YNvZd2SHPQxZ4eyR3uViaTFCM3SvI5nks,10426
-liger_kernel/transformers/model/mistral.py,sha256=XpL1rlWg_llvW3z_Hf_d8WQs7uQaH4ds7EZ2SxjQHsU,5144
-liger_kernel/transformers/model/mixtral.py,sha256=JlNS6DA6SJqeHDk7j2LZymPQ3wngrTIo3wUGFBqHuJs,11504
-liger_kernel/transformers/model/mllama.py,sha256=mesNCgj0Ea1O-fqRD4LVxDJ1CR2abY_zAzK_bfVzkiU,11222
-liger_kernel/transformers/model/phi3.py,sha256=xUZPlaPKwknLjHc3uUW3EPodm1h0vD3G7Qnhh51v-Io,10332
-liger_kernel/transformers/model/qwen2.py,sha256=EyhSSzQOskGjSnCsKMZpd1s5IAIlHd5PBO3q0MoCs00,9619
-liger_kernel/transformers/model/qwen2_vl.py,sha256=bIQe2bWiY--G84FhCD29Gdi64_qHP6vbcGsK6vKysQE,8547
-liger_kernel/transformers/trainer/__init__.py,sha256=c4OQVJmhNOloj0JYSEc0j_cQuBbzGWILfaowUR1hmRw,210
-liger_kernel/transformers/trainer/orpo_trainer.py,sha256=O2k2vdHl-O1S-U61aEmyUFu3QrEuNAipQa2oUBb3HAA,7679
-liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
-liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
-liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
-liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/METADATA,sha256=rY2y3vkXwGKfZpmRsIIbD9BwAVpeYe6wbVwKJbMWB8k,21055
-liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
-liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
-liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
-liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/RECORD,,
{liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/LICENSE
File without changes
{liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/NOTICE
File without changes
{liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/WHEEL
File without changes
{liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/top_level.txt
File without changes