liger-kernel-nightly 0.5.3.dev20250221003838__py3-none-any.whl → 0.5.3.dev20250221011147__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -9,6 +9,7 @@ from liger_kernel.transformers.monkey_patch import _apply_liger_kernel # noqa:
9
9
  from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance # noqa: F401
10
10
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma # noqa: F401
11
11
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2 # noqa: F401
12
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401
12
13
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401
13
14
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral # noqa: F401
14
15
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral # noqa: F401
@@ -61,6 +61,85 @@ def _patch_layer_norm_module(module, eps=1e-6):
61
61
  _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
62
62
 
63
63
 
64
def apply_liger_kernel_to_granite(
    rope: bool = True,
    cross_entropy: bool = True,
    fused_linear_cross_entropy: bool = False,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Granite 3 models.

    The class/function replacements are made on the module
    ``transformers.models.granite.modeling_granite`` itself, so they affect every Granite
    model constructed after this call. If ``model`` is given, the modules it has already
    instantiated are patched in place as well.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is False.
            NOT supported for Granite: passing True raises ``NotImplementedError`` before
            any patch is applied. `cross_entropy` and `fused_linear_cross_entropy` cannot
            both be True.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
            loaded. Default is None.

    Raises:
        AssertionError: If both `cross_entropy` and `fused_linear_cross_entropy` are True
            (this check is skipped when Python runs with ``-O``).
        NotImplementedError: If `fused_linear_cross_entropy` is True.

    Notes:
        NOTE(review): the original author recorded that LigerSwiGLUMLP, although it works
        for Llama, did not behave correctly for Granite ("it should be fine for Granite,
        but it's not"). The cause is not documented here; confirm before relying on
        ``swiglu=True`` for Granite.
    """

    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    if fused_linear_cross_entropy:
        # Reject before touching any global state, so a failed call does not leave
        # modeling_granite half-patched (rope / RMSNorm / MLP already swapped out).
        # Granite's `GraniteForCausalLM.forward` scales the logits on every call, so
        # logit materialization can't be sidestepped; supporting this would need a
        # scaling term added to `LigerFusedLinearCrossEntropyFunction`.
        raise NotImplementedError("LigerFusedLinearCrossEntropy is not available for Granite models.")

    from transformers.models.granite import modeling_granite
    from transformers.models.granite.modeling_granite import GraniteModel

    # Class-level patches: affect modules constructed after this point.
    if swiglu:
        modeling_granite.GraniteMLP = LigerSwiGLUMLP

    if rms_norm:
        modeling_granite.GraniteRMSNorm = LigerRMSNorm

    if rope:
        modeling_granite.apply_rotary_pos_emb = liger_rotary_pos_emb

    if cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            # NOTE(review): `nn` here is presumably `torch.nn` as imported by
            # transformers.loss.loss_utils, in which case this assignment replaces
            # `torch.nn.functional.cross_entropy` process-wide - confirm.
            from transformers.loss.loss_utils import nn

            nn.functional.cross_entropy = liger_cross_entropy
        else:
            # Older transformers: warn, then swap the loss class used by modeling_granite.
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_granite.CrossEntropyLoss = LigerCrossEntropyLoss

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules (e.g. GraniteRMSNorm or GraniteMLP)

        # Get the base model from the model instance (fall back to the model itself
        # when it has no attribute named by `base_model_prefix`).
        base_model: GraniteModel = getattr(model, model.base_model_prefix, model)

        if rms_norm:
            _patch_rms_norm_module(base_model.norm)

        for decoder_layer in base_model.layers:
            if swiglu:
                # Rebind only `forward`; the existing MLP submodules/weights are kept.
                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
141
+
142
+
64
143
  def apply_liger_kernel_to_llama(
65
144
  rope: bool = True,
66
145
  cross_entropy: bool = False,
@@ -740,6 +819,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
740
819
  "gemma": apply_liger_kernel_to_gemma,
741
820
  "gemma2": apply_liger_kernel_to_gemma2,
742
821
  "llama": apply_liger_kernel_to_llama,
822
+ "granite": apply_liger_kernel_to_granite,
743
823
  "mllama": apply_liger_kernel_to_mllama,
744
824
  "mllama_text_model": apply_liger_kernel_to_mllama,
745
825
  "mistral": apply_liger_kernel_to_mistral,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.3.dev20250221003838
3
+ Version: 0.5.3.dev20250221011147
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -32,7 +32,7 @@ liger_kernel/ops/tvd.py,sha256=9wVCijj2vBtgiLeUHhl7hy_LAiJ3liPIYOGMSU3P1ro,6407
32
32
  liger_kernel/ops/utils.py,sha256=uoFKQqo-34N2TWQNvXMFywqGiOMMXNEVBxVojzlUAa0,3836
33
33
  liger_kernel/ops/experimental/embedding.py,sha256=tolj3tItkzpSb30zWqDN2_yX4ectflaQ8HMyKyFIQc8,4172
34
34
  liger_kernel/ops/experimental/mm_int8int2.py,sha256=TrS9lpwekrik_w5qE7AhMJD1bcq-OidjtbsW80oZ6IM,13314
35
- liger_kernel/transformers/__init__.py,sha256=VZI9hiCvvA371jsfkJmSt1CNXlBztIvlVGDExyKeqBM,2077
35
+ liger_kernel/transformers/__init__.py,sha256=i6GPkP5-esFBh205nF4MluNrL7KNugseGiUKdSHGW70,2172
36
36
  liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawXOGgoCZjgEg,1545
37
37
  liger_kernel/transformers/cross_entropy.py,sha256=z3KTWQnFxr_IZaVjtYt0ZNEWQdDdYThN35xWkHlDGH0,1683
38
38
  liger_kernel/transformers/functional.py,sha256=zahXVCjA2NxcVFpAgajILIRN0GO6mrbfLPgONUkTrY8,4940
@@ -43,7 +43,7 @@ liger_kernel/transformers/group_norm.py,sha256=6qMAWOprr4SzP0YhNVNGQIBpM5aUHplUD
43
43
  liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCcScY,2979
44
44
  liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
45
45
  liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
46
- liger_kernel/transformers/monkey_patch.py,sha256=DXU00zsQvSjAqCx7l36gKm1O81FuHgILkZMhyx4ZSys,37812
46
+ liger_kernel/transformers/monkey_patch.py,sha256=kSJE1aMLN0e4jxiePUISRPcyvhWyzhOaqIwLW6rG0Zo,41191
47
47
  liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
48
48
  liger_kernel/transformers/rms_norm.py,sha256=GqCEJuGt0YdqqlMcToE0Wp4A8YFquDa4UUSyH2uFW2A,1191
49
49
  liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
@@ -65,9 +65,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
65
65
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
66
66
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
67
67
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
68
- liger_kernel_nightly-0.5.3.dev20250221003838.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
69
- liger_kernel_nightly-0.5.3.dev20250221003838.dist-info/METADATA,sha256=zJp1YMQDbbOzeNRKFf7AN7hYqePdT49OEMQkN_buKl8,21963
70
- liger_kernel_nightly-0.5.3.dev20250221003838.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
71
- liger_kernel_nightly-0.5.3.dev20250221003838.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
72
- liger_kernel_nightly-0.5.3.dev20250221003838.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
73
- liger_kernel_nightly-0.5.3.dev20250221003838.dist-info/RECORD,,
68
+ liger_kernel_nightly-0.5.3.dev20250221011147.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
69
+ liger_kernel_nightly-0.5.3.dev20250221011147.dist-info/METADATA,sha256=TuBMogIh5RNr68Eksjwz1MQukdz1w_HLF76Ryh2a06M,21963
70
+ liger_kernel_nightly-0.5.3.dev20250221011147.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
71
+ liger_kernel_nightly-0.5.3.dev20250221011147.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
72
+ liger_kernel_nightly-0.5.3.dev20250221011147.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
73
+ liger_kernel_nightly-0.5.3.dev20250221011147.dist-info/RECORD,,