liger-kernel 0.3.1__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/env_report.py +2 -0
- liger_kernel/ops/cross_entropy.py +144 -65
- liger_kernel/ops/experimental/mm_int8int2.py +355 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +31 -11
- liger_kernel/ops/fused_linear_jsd.py +245 -0
- liger_kernel/ops/geglu.py +2 -2
- liger_kernel/ops/group_norm.py +322 -0
- liger_kernel/ops/jsd.py +176 -0
- liger_kernel/ops/kl_div.py +2 -2
- liger_kernel/ops/rms_norm.py +92 -46
- liger_kernel/ops/swiglu.py +2 -2
- liger_kernel/ops/utils.py +62 -1
- liger_kernel/transformers/__init__.py +3 -0
- liger_kernel/transformers/cross_entropy.py +44 -12
- liger_kernel/transformers/functional.py +38 -1
- liger_kernel/transformers/fused_linear_cross_entropy.py +31 -4
- liger_kernel/transformers/fused_linear_jsd.py +98 -0
- liger_kernel/transformers/group_norm.py +56 -0
- liger_kernel/transformers/jsd.py +75 -0
- liger_kernel/transformers/model/gemma.py +124 -1
- liger_kernel/transformers/model/gemma2.py +277 -0
- liger_kernel/transformers/model/llama.py +135 -4
- liger_kernel/transformers/model/mistral.py +3 -0
- liger_kernel/transformers/model/mixtral.py +153 -2
- liger_kernel/transformers/model/mllama.py +274 -0
- liger_kernel/transformers/model/phi3.py +140 -2
- liger_kernel/transformers/model/qwen2.py +123 -2
- liger_kernel/transformers/model/qwen2_vl.py +8 -1
- liger_kernel/transformers/monkey_patch.py +258 -68
- liger_kernel/transformers/rms_norm.py +11 -3
- {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.1.dist-info}/METADATA +63 -29
- liger_kernel-0.4.1.dist-info/NOTICE +58 -0
- liger_kernel-0.4.1.dist-info/RECORD +51 -0
- {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.1.dist-info}/WHEEL +1 -1
- liger_kernel-0.3.1.dist-info/NOTICE +0 -4
- liger_kernel-0.3.1.dist-info/RECORD +0 -42
- {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.1.dist-info}/LICENSE +0 -0
- {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.1.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/monkey_patch.py

@@ -3,17 +3,39 @@ import logging
 from functools import partial
 from typing import Callable

+import transformers
+from packaging import version
 from transformers import PreTrainedModel

 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
+from liger_kernel.transformers.functional import liger_cross_entropy
 from liger_kernel.transformers.geglu import LigerGEGLUMLP
 from liger_kernel.transformers.layer_norm import LigerLayerNorm
 from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
+from liger_kernel.transformers.model.gemma import (
+    lce_forward_deprecated as gemma_lce_forward_deprecated,
+)
+from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
+from liger_kernel.transformers.model.gemma2 import (
+    lce_forward_deprecated as gemma2_lce_forward_deprected,
+)
 from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
+from liger_kernel.transformers.model.llama import (
+    lce_forward_deprecated as llama_lce_forward_deprecated,
+)
 from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
 from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
+from liger_kernel.transformers.model.mixtral import (
+    lce_forward_deprecated as mixtral_lce_forward_deprecated,
+)
 from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
+from liger_kernel.transformers.model.phi3 import (
+    lce_forward_deprecated as phi3_lce_forward_deprecated,
+)
 from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
+from liger_kernel.transformers.model.qwen2 import (
+    lce_forward_deprecated as qwen2_lce_forward_deprecated,
+)
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
 from liger_kernel.transformers.swiglu import (
@@ -22,7 +44,11 @@ from liger_kernel.transformers.swiglu import (
     LigerSwiGLUMLP,
 )

+transformer_version = version.parse(transformers.__version__)
+
 logger = logging.getLogger(__name__)
+SUPPORTED_TRANSFORMER_VERSION = "4.46.1"
+TRANSFORMER_DEPRECATION_WARNING = "Support for transformers versions < 4.46.1 will soon be discontinued due to issues with incorrect gradient accumulation. \n Please consider upgrading to avoid potential issues. See details: https://github.com/huggingface/transformers/pull/34191"


 def _bind_method_to_module(module, method_name: str, new_method: Callable):
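
The parsed `transformer_version` and the two new constants above drive the version gate that every patching function below repeats: transformers >= 4.46.1 patches the shared loss helper, older versions get a warning and the deprecated per-model forward. A minimal standalone sketch of that gate, assuming only `transformers` and `packaging` are installed (the branch bodies are placeholders, not the library's code):

import transformers
from packaging import version

SUPPORTED_TRANSFORMER_VERSION = "4.46.1"
transformer_version = version.parse(transformers.__version__)

if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
    # >= 4.46.1: patch the shared loss path (transformers.loss.loss_utils), which
    # keeps gradient accumulation correct (see huggingface/transformers#34191)
    print("use the new loss_utils-based patching path")
else:
    # < 4.46.1: warn and fall back to the *_lce_forward_deprecated forward
    print("warn and use the deprecated per-model patching path")
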
@@ -78,6 +104,7 @@ def apply_liger_kernel_to_llama(
     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

     from transformers.models.llama import modeling_llama
+    from transformers.models.llama.modeling_llama import LlamaModel

     if rope:
         modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
@@ -85,24 +112,29 @@ def apply_liger_kernel_to_llama(
         modeling_llama.LlamaRMSNorm = LigerRMSNorm
     if swiglu:
         modeling_llama.LlamaMLP = LigerSwiGLUMLP
+
     if cross_entropy:
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
+
     if fused_linear_cross_entropy:
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
+        else: # if version < 4.46.1
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)

-            base_model = model.model
-        elif hasattr(model, "transformer"):
-            # LlamaForQuestionAnswering uses "transformer" instead of "model"
-            base_model = model.transformer
-        else:
-            # Direct LlamaModel
-            base_model = model
+        # get the base model from the model instance
+        base_model: LlamaModel = getattr(model, model.base_model_prefix, model)

         if rms_norm:
             _patch_rms_norm_module(base_model.norm)
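
The calling convention is unchanged from 0.3.x: run the patch before instantiating the model so the class-level replacements above take effect. A usage sketch (the checkpoint name is illustrative, not taken from this diff):

import torch
from transformers import AutoModelForCausalLM

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama

# Patch modeling_llama before the model is built; with transformers >= 4.46.1
# the loss is routed through liger_cross_entropy / llama_lce_forward.
apply_liger_kernel_to_llama(
    rope=True, rms_norm=True, swiglu=True, fused_linear_cross_entropy=True
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",  # illustrative checkpoint
    torch_dtype=torch.bfloat16,
)
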
@@ -117,6 +149,116 @@ def apply_liger_kernel_to_llama(
             _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


+def apply_liger_kernel_to_mllama(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    layer_norm: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace MLlama models.
+    NOTE: MLlama is not available in transformers<4.45.0
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+
+    assert not (
+        cross_entropy and fused_linear_cross_entropy
+    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+
+    from transformers.models.mllama import modeling_mllama
+    from transformers.models.mllama.modeling_mllama import (
+        MllamaForCausalLM,
+        MllamaForConditionalGeneration,
+        MllamaTextModel,
+        MllamaVisionModel,
+    )
+
+    from liger_kernel.transformers.model.mllama import lce_forward as mllama_lce_forward
+    from liger_kernel.transformers.model.mllama import (
+        lce_forward_deprecated as mllama_lce_forward_deprecated,
+    )
+
+    if rope:
+        modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if layer_norm:
+        modeling_mllama.nn.LayerNorm = LigerLayerNorm
+    if rms_norm:
+        modeling_mllama.MllamaTextRMSNorm = LigerRMSNorm
+    if swiglu:
+        modeling_mllama.MllamaTextMLP = LigerSwiGLUMLP
+    if cross_entropy:
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
+    if fused_linear_cross_entropy:
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
+        else: # if version < 4.46.1
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        if isinstance(model, MllamaForConditionalGeneration):
+            language_model: MllamaForCausalLM = model.language_model
+            vision_model: MllamaVisionModel = model.vision_model
+            text_model: MllamaTextModel = language_model.model
+        elif isinstance(model, MllamaForCausalLM):
+            text_model = model.model
+            vision_model = None
+        elif isinstance(model, MllamaTextModel):
+            text_model = model
+            vision_model = None
+        else:
+            raise ValueError(f"Unsupported Mllama model type: {type(model)}")
+
+        if text_model:
+            if rms_norm:
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    _bind_method_to_module(
+                        decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
+                    )
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+        if vision_model:
+            _patch_layer_norm_module(vision_model.layernorm_pre)
+            _patch_layer_norm_module(vision_model.layernorm_post)
+
+            for layer in vision_model.transformer.layers:
+                if layer_norm:
+                    _patch_layer_norm_module(layer.input_layernorm)
+                    _patch_layer_norm_module(layer.post_attention_layernorm)
+
+            for layer in vision_model.global_transformer.layers:
+                if layer_norm:
+                    _patch_layer_norm_module(layer.input_layernorm)
+                    _patch_layer_norm_module(layer.post_attention_layernorm)
+
+
 def apply_liger_kernel_to_mistral(
     rope: bool = True,
     cross_entropy: bool = False,
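
The new `apply_liger_kernel_to_mllama` also accepts an already-loaded instance, in which case the `model is not None` branch above walks the text decoder and vision tower and patches the instantiated modules. A sketch, with the checkpoint name purely illustrative:

from transformers import MllamaForConditionalGeneration

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama

model = MllamaForConditionalGeneration.from_pretrained(
    "meta-llama/Llama-3.2-11B-Vision-Instruct"  # illustrative checkpoint
)

# Patches RMSNorm/SwiGLU on the text decoder and LayerNorm on the vision tower
# of this instance, in addition to the module-level patches.
apply_liger_kernel_to_mllama(model=model)
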
@@ -129,7 +271,7 @@ def apply_liger_kernel_to_mistral(
     Apply Liger kernels to replace original implementation in HuggingFace Mistral models

     Args:
-        rope (bool): Whether to apply Liger's rotary position embedding. Default is
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
         fused_linear_cross_entropy (bool):
             Whether to apply Liger's fused linear cross entropy loss. Default is True.
@@ -146,6 +288,7 @@ def apply_liger_kernel_to_mistral(
     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

     from transformers.models.mistral import modeling_mistral
+    from transformers.models.mistral.modeling_mistral import MistralModel

     if rope:
         modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
@@ -162,12 +305,8 @@ def apply_liger_kernel_to_mistral(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules

-            base_model = model.model
-        else:
-            # Direct MistralModel
-            base_model = model
+        # get the base model from the model instance
+        base_model: MistralModel = getattr(model, model.base_model_prefix, model)

         if rms_norm:
             _patch_rms_norm_module(base_model.norm)
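
The recurring `getattr(model, model.base_model_prefix, model)` line is what replaces the old per-architecture `hasattr` chains: `base_model_prefix` is a class attribute on every transformers `PreTrainedModel`, so the same expression unwraps a `*ForCausalLM` wrapper to its inner base model and falls back to the model itself when a bare base model is passed. A small illustration with a tiny random-weight config (sizes are arbitrary):

from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel

cfg = LlamaConfig(
    hidden_size=64, intermediate_size=128, num_hidden_layers=2,
    num_attention_heads=4, num_key_value_heads=4, vocab_size=128,
)

wrapper = LlamaForCausalLM(cfg)   # base_model_prefix == "model"
bare = LlamaModel(cfg)

# The wrapper exposes its base model under the prefix attribute ...
assert getattr(wrapper, wrapper.base_model_prefix, wrapper) is wrapper.model
# ... while a bare base model has no such attribute and resolves to itself.
assert getattr(bare, bare.base_model_prefix, bare) is bare
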
@@ -211,15 +350,27 @@ def apply_liger_kernel_to_mixtral(
     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

     from transformers.models.mixtral import modeling_mixtral
+    from transformers.models.mixtral.modeling_mixtral import MixtralModel

     if rope:
         modeling_mixtral.apply_rotary_pos_emb = liger_rotary_pos_emb
     if rms_norm:
         modeling_mixtral.MixtralRMSNorm = LigerRMSNorm
     if cross_entropy:
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_mixtral.CrossEntropyLoss = LigerCrossEntropyLoss
+
     if fused_linear_cross_entropy:
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
+        else: # if version < 4.46.1
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
     if swiglu:
         modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP

@@ -227,12 +378,8 @@ def apply_liger_kernel_to_mixtral(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules

-            base_model = model.model
-        else:
-            # Direct MixtralModel
-            base_model = model
+        # get the base model from the model instance
+        base_model: MixtralModel = getattr(model, model.base_model_prefix, model)

         if rms_norm:
             _patch_rms_norm_module(base_model.norm)
@@ -277,6 +424,7 @@ def apply_liger_kernel_to_gemma(
     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

     from transformers.models.gemma import modeling_gemma
+    from transformers.models.gemma.modeling_gemma import GemmaModel

     # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
     LigerRMSNormForGemma = partial(
@@ -291,22 +439,28 @@ def apply_liger_kernel_to_gemma(
     if rms_norm:
         modeling_gemma.GemmaRMSNorm = LigerRMSNormForGemma
     if cross_entropy:
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
     if geglu:
         modeling_gemma.GemmaMLP = LigerGEGLUMLP
     if fused_linear_cross_entropy:
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
+        else: # if version < 4.46.1
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules

-            base_model = model.model
-        else:
-            # Direct GemmaModel
-            base_model = model
+        # get the base model from the model instance
+        base_model: GemmaModel = getattr(model, model.base_model_prefix, model)

         if rms_norm:
             _patch_rms_norm_module_for_gemma(base_model.norm)
@@ -323,7 +477,8 @@ def apply_liger_kernel_to_gemma(

 def apply_liger_kernel_to_gemma2(
     rope: bool = True,
-    cross_entropy: bool =
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
     rms_norm: bool = True,
     geglu: bool = True,
     model: PreTrainedModel = None,
@@ -334,16 +489,25 @@ def apply_liger_kernel_to_gemma2(

     Args:
         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
-        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
         geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
         loaded. Default is None.
     """
+    assert not (
+        cross_entropy and fused_linear_cross_entropy
+    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+
     from transformers.models.gemma2 import modeling_gemma2
+    from transformers.models.gemma2.modeling_gemma2 import Gemma2Model

     LigerRMSNormForGemma2 = partial(
-        LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros"
+        LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False
     )
     _patch_rms_norm_module_for_gemma2 = partial(
         _patch_rms_norm_module, offset=1.0, casting_mode="gemma"
@@ -355,7 +519,19 @@ def apply_liger_kernel_to_gemma2(
         # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
         modeling_gemma2.Gemma2RMSNorm = LigerRMSNormForGemma2
     if cross_entropy:
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
+    if fused_linear_cross_entropy:
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
     if geglu:
         modeling_gemma2.Gemma2MLP = LigerGEGLUMLP

@@ -363,12 +539,8 @@ def apply_liger_kernel_to_gemma2(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules

-            base_model = model.model
-        else:
-            # Direct Gemma2Model
-            base_model = model
+        # get the base model from the model instance
+        base_model: Gemma2Model = getattr(model, model.base_model_prefix, model)

         if rms_norm:
             _patch_rms_norm_module_for_gemma2(base_model.norm)
@@ -419,15 +591,31 @@ def apply_liger_kernel_to_qwen2(
     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

     from transformers.models.qwen2 import modeling_qwen2
+    from transformers.models.qwen2.modeling_qwen2 import Qwen2Model

     if rope:
         modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
     if rms_norm:
         modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
+
     if cross_entropy:
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
+
+    # import pdb; pdb.set_trace()
     if fused_linear_cross_entropy:
+
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
+        else: # if version < 4.46.1
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated
+
     if swiglu:
         modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP

@@ -435,12 +623,8 @@ def apply_liger_kernel_to_qwen2(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules

-            base_model = model.model
-        else:
-            # Direct Qwen2Model
-            base_model = model
+        # get the base model from the model instance
+        base_model: Qwen2Model = getattr(model, model.base_model_prefix, model)

         if rms_norm:
             _patch_rms_norm_module(base_model.norm)
@@ -453,6 +637,7 @@ def apply_liger_kernel_to_qwen2(
                 if rms_norm:
                     _patch_rms_norm_module(decoder_layer.input_layernorm)
                     _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+    print("Applied Liger kernels to Qwen2")


 def apply_liger_kernel_to_qwen2_vl(
@@ -465,7 +650,7 @@ def apply_liger_kernel_to_qwen2_vl(
 ) -> None:
     """
     Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
-    NOTE: Qwen2-VL is not available in transformers
+    NOTE: Qwen2-VL is not available in transformers<4.45.0

     Args:
         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
@@ -484,6 +669,7 @@ def apply_liger_kernel_to_qwen2_vl(
     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

     from transformers.models.qwen2_vl import modeling_qwen2_vl
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel

     from liger_kernel.transformers.model.qwen2_vl import (
         lce_forward as qwen2_vl_lce_forward,
@@ -507,12 +693,8 @@ def apply_liger_kernel_to_qwen2_vl(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules

-            base_model = model.model
-        else:
-            # Direct Qwen2VLModel
-            base_model = model
+        # get the base model from the model instance
+        base_model: Qwen2VLModel = getattr(model, model.base_model_prefix, model)

         if hasattr(model, "visual"):
             # Patch Qwen2VisionTransformerPretrainedModel
@@ -561,6 +743,7 @@ def apply_liger_kernel_to_phi3(
     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

     from transformers.models.phi3 import modeling_phi3
+    from transformers.models.phi3.modeling_phi3 import Phi3Model

     if rope:
         modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb # Same as Gemma
@@ -569,20 +752,26 @@ def apply_liger_kernel_to_phi3(
     if swiglu:
         modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
     if cross_entropy:
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
+        else: # if version < 4.46.1
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules

-            base_model = model.model
-        else:
-            # Direct Phi3Model
-            base_model = model
+        # get the base model from the model instance
+        base_model: Phi3Model = getattr(model, model.base_model_prefix, model)

         if rms_norm:
             _patch_rms_norm_module(base_model.norm)
@@ -602,6 +791,8 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "gemma": apply_liger_kernel_to_gemma,
     "gemma2": apply_liger_kernel_to_gemma2,
     "llama": apply_liger_kernel_to_llama,
+    "mllama": apply_liger_kernel_to_mllama,
+    "mllama_text_model": apply_liger_kernel_to_mllama,
     "mistral": apply_liger_kernel_to_mistral,
     "mixtral": apply_liger_kernel_to_mixtral,
     "qwen2": apply_liger_kernel_to_qwen2,
@@ -687,7 +878,6 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
         for key, value in kwargs.items()
         if key in apply_fn_signature.parameters
     }
-
     logger.info(
         f"Applying Liger kernels to model instance with model type: {model_type} with kwargs: {applicable_kwargs}"
    )
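
With `mllama` and `mllama_text_model` registered in `MODEL_TYPE_TO_APPLY_LIGER_FN`, instance-level patching can dispatch on a loaded model's `config.model_type`. A rough sketch of that dispatch; `patch_instance` is a hypothetical stand-in for the library's `_apply_liger_kernel_to_instance`, which additionally filters kwargs against the target function's signature:

from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN


def patch_instance(model, **kwargs):
    # Dispatch on the HF config's model_type, e.g. "mllama" or "qwen2".
    model_type = getattr(model.config, "model_type", None)
    apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(model_type)
    if apply_fn is None:
        raise ValueError(f"No Liger kernel patcher registered for model_type={model_type!r}")
    apply_fn(model=model, **kwargs)
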
liger_kernel/transformers/rms_norm.py

@@ -6,7 +6,13 @@ from liger_kernel.ops.rms_norm import LigerRMSNormFunction

 class LigerRMSNorm(nn.Module):
     def __init__(
-        self,
+        self,
+        hidden_size,
+        eps=1e-6,
+        offset=0.0,
+        casting_mode="llama",
+        init_fn="ones",
+        in_place=True,
     ):
         super().__init__()
         assert init_fn in [
@@ -16,10 +22,11 @@ class LigerRMSNorm(nn.Module):
         self.weight = nn.Parameter(
             torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)
         )
-        self.variance_epsilon, self.offset, self.casting_mode = (
+        self.variance_epsilon, self.offset, self.casting_mode, self.in_place = (
             eps,
             offset,
             casting_mode,
+            in_place,
         )

     def forward(self, hidden_states):
@@ -29,7 +36,8 @@ class LigerRMSNorm(nn.Module):
             self.variance_epsilon,
             self.offset,
             self.casting_mode,
+            self.in_place,
         )

     def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}"
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}"
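
For reference, the widened constructor makes the new `in_place` switch part of the public module; a short sketch of using it directly (requires a CUDA device, since the underlying op is a Triton kernel; sizes are arbitrary):

import torch

from liger_kernel.transformers.rms_norm import LigerRMSNorm

# in_place=False is what the Gemma2 patch above selects; the default stays True.
norm = LigerRMSNorm(hidden_size=64, eps=1e-6, offset=0.0, in_place=False).cuda()

x = torch.randn(2, 8, 64, device="cuda", requires_grad=True)
y = norm(x)
print(norm)      # extra_repr now reports eps, offset, and in_place
print(y.shape)   # torch.Size([2, 8, 64])
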