liger-kernel 0.1.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/env_report.py +46 -0
- liger_kernel/ops/cross_entropy.py +130 -63
- liger_kernel/ops/experimental/embedding.py +143 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +203 -126
- liger_kernel/ops/geglu.py +56 -44
- liger_kernel/ops/kl_div.py +258 -0
- liger_kernel/ops/layer_norm.py +236 -0
- liger_kernel/ops/rms_norm.py +220 -84
- liger_kernel/ops/rope.py +91 -84
- liger_kernel/ops/swiglu.py +50 -43
- liger_kernel/ops/utils.py +12 -0
- liger_kernel/transformers/__init__.py +22 -0
- liger_kernel/transformers/auto_model.py +45 -0
- liger_kernel/transformers/cross_entropy.py +11 -1
- liger_kernel/transformers/experimental/embedding.py +28 -0
- liger_kernel/transformers/functional.py +19 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +8 -2
- liger_kernel/transformers/geglu.py +4 -2
- liger_kernel/transformers/kl_div.py +14 -0
- liger_kernel/transformers/layer_norm.py +30 -0
- liger_kernel/transformers/model/gemma.py +138 -0
- liger_kernel/transformers/model/llama.py +1 -1
- liger_kernel/transformers/model/mistral.py +138 -0
- liger_kernel/transformers/model/mixtral.py +158 -0
- liger_kernel/transformers/model/phi3.py +136 -0
- liger_kernel/transformers/model/qwen2.py +135 -0
- liger_kernel/transformers/model/qwen2_vl.py +172 -0
- liger_kernel/transformers/monkey_patch.py +579 -14
- liger_kernel/transformers/rms_norm.py +23 -4
- liger_kernel/transformers/swiglu.py +24 -0
- liger_kernel/transformers/trainer_integration.py +2 -45
- liger_kernel-0.3.1.dist-info/METADATA +395 -0
- liger_kernel-0.3.1.dist-info/RECORD +42 -0
- {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.1.dist-info}/WHEEL +1 -1
- liger_kernel-0.1.0.dist-info/METADATA +0 -16
- liger_kernel-0.1.0.dist-info/RECORD +0 -27
- {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.1.dist-info}/LICENSE +0 -0
- {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.1.dist-info}/NOTICE +0 -0
- {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.1.dist-info}/top_level.txt +0 -0
|
@@ -1,9 +1,52 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
import logging
|
|
3
|
+
from functools import partial
|
|
4
|
+
from typing import Callable
|
|
5
|
+
|
|
6
|
+
from transformers import PreTrainedModel
|
|
7
|
+
|
|
1
8
|
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
|
2
9
|
from liger_kernel.transformers.geglu import LigerGEGLUMLP
|
|
3
|
-
from liger_kernel.transformers.
|
|
10
|
+
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
|
11
|
+
from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
|
|
12
|
+
from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
|
|
13
|
+
from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
|
|
14
|
+
from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
|
|
15
|
+
from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
|
|
16
|
+
from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
|
|
4
17
|
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
|
5
18
|
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
|
6
|
-
from liger_kernel.transformers.swiglu import
|
|
19
|
+
from liger_kernel.transformers.swiglu import (
|
|
20
|
+
LigerBlockSparseTop2MLP,
|
|
21
|
+
LigerPhi3SwiGLUMLP,
|
|
22
|
+
LigerSwiGLUMLP,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
# Module-level logger; the `_apply_liger_kernel*` dispatch helpers in this file report through it.
logger = logging.getLogger(name=__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _bind_method_to_module(module, method_name: str, new_method: Callable):
|
|
29
|
+
# Binds a new method to a module instance so that self is passed as the first argument
|
|
30
|
+
module.__dict__[method_name] = new_method.__get__(module, module.__class__)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama"):
    """Convert an already-built RMSNorm instance to run Liger's RMSNorm in place.

    No parameters are created or replaced. The module gets ``offset``,
    ``casting_mode`` and ``variance_epsilon`` attributes, and its ``forward`` /
    ``extra_repr`` are rebound (for this instance only) to ``LigerRMSNorm``'s.
    Assumes the module already carries whatever else ``LigerRMSNorm.forward``
    reads (e.g. a ``weight``) — TODO confirm against LigerRMSNorm.

    Args:
        module: the RMSNorm instance to patch.
        offset: stored as ``module.offset`` (the Gemma patchers in this file pass 1.0).
        eps: fallback epsilon, used only when the module has no truthy
            ``variance_epsilon`` or ``eps`` attribute.
        casting_mode: stored as ``module.casting_mode`` ("llama" by default,
            "gemma" for the Gemma patchers).
    """
    module.offset = offset
    module.casting_mode = casting_mode

    # HF norm classes store epsilon as either `variance_epsilon` or `eps`.
    # The first truthy one wins, so a stored 0 counts as "missing".
    resolved_eps = eps
    for candidate in ("variance_epsilon", "eps"):
        found = getattr(module, candidate, None)
        if found:
            resolved_eps = found
            break
    module.variance_epsilon = resolved_eps

    for attr_name in ("forward", "extra_repr"):
        _bind_method_to_module(module, attr_name, getattr(LigerRMSNorm, attr_name))
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _patch_layer_norm_module(module, eps=1e-6):
    """Convert an already-built LayerNorm instance to run Liger's LayerNorm in place.

    Sets ``variance_epsilon``, mirrors ``normalized_shape`` under the name
    ``hidden_size`` (presumably because LigerLayerNorm's methods read
    ``hidden_size`` — verify), then rebinds ``forward`` and ``extra_repr`` on this
    instance to ``LigerLayerNorm``'s implementations.

    Args:
        module: the LayerNorm instance to patch; must expose ``normalized_shape``.
        eps: fallback epsilon, used only when the module has no truthy
            ``variance_epsilon`` or ``eps`` attribute.
    """
    # First truthy of `variance_epsilon` / `eps`; otherwise the `eps` argument.
    resolved_eps = eps
    for candidate in ("variance_epsilon", "eps"):
        found = getattr(module, candidate, None)
        if found:
            resolved_eps = found
            break
    module.variance_epsilon = resolved_eps

    module.hidden_size = module.normalized_shape

    for attr_name in ("forward", "extra_repr"):
        _bind_method_to_module(module, attr_name, getattr(LigerLayerNorm, attr_name))
|
|
7
50
|
|
|
8
51
|
|
|
9
52
|
def apply_liger_kernel_to_llama(
|
|
@@ -12,6 +55,7 @@ def apply_liger_kernel_to_llama(
|
|
|
12
55
|
fused_linear_cross_entropy: bool = True,
|
|
13
56
|
rms_norm: bool = True,
|
|
14
57
|
swiglu: bool = True,
|
|
58
|
+
model: PreTrainedModel = None,
|
|
15
59
|
) -> None:
|
|
16
60
|
"""
|
|
17
61
|
Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3)
|
|
@@ -20,11 +64,13 @@ def apply_liger_kernel_to_llama(
|
|
|
20
64
|
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
|
|
21
65
|
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
22
66
|
fused_linear_cross_entropy (bool):
|
|
23
|
-
Whether to apply Liger's fused
|
|
67
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
24
68
|
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
25
69
|
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
26
70
|
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
27
71
|
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
|
|
72
|
+
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
73
|
+
loaded. Default is None.
|
|
28
74
|
"""
|
|
29
75
|
|
|
30
76
|
assert not (
|
|
@@ -42,14 +88,42 @@ def apply_liger_kernel_to_llama(
|
|
|
42
88
|
if cross_entropy:
|
|
43
89
|
modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
44
90
|
if fused_linear_cross_entropy:
|
|
45
|
-
modeling_llama.LlamaForCausalLM.forward =
|
|
91
|
+
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
|
|
92
|
+
|
|
93
|
+
if model is not None:
|
|
94
|
+
# The model instance already exists, so we need to additionally patch the
|
|
95
|
+
# instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)
|
|
96
|
+
|
|
97
|
+
if hasattr(model, "model"):
|
|
98
|
+
# The case for LlamaForCausalLM or LlamaForSequenceClassification, for example
|
|
99
|
+
base_model = model.model
|
|
100
|
+
elif hasattr(model, "transformer"):
|
|
101
|
+
# LlamaForQuestionAnswering uses "transformer" instead of "model"
|
|
102
|
+
base_model = model.transformer
|
|
103
|
+
else:
|
|
104
|
+
# Direct LlamaModel
|
|
105
|
+
base_model = model
|
|
106
|
+
|
|
107
|
+
if rms_norm:
|
|
108
|
+
_patch_rms_norm_module(base_model.norm)
|
|
109
|
+
|
|
110
|
+
for decoder_layer in base_model.layers:
|
|
111
|
+
if swiglu:
|
|
112
|
+
_bind_method_to_module(
|
|
113
|
+
decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
|
|
114
|
+
)
|
|
115
|
+
if rms_norm:
|
|
116
|
+
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
117
|
+
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
46
118
|
|
|
47
119
|
|
|
48
120
|
def apply_liger_kernel_to_mistral(
|
|
49
121
|
rope: bool = True,
|
|
50
|
-
cross_entropy: bool =
|
|
122
|
+
cross_entropy: bool = False,
|
|
123
|
+
fused_linear_cross_entropy: bool = True,
|
|
51
124
|
rms_norm: bool = True,
|
|
52
125
|
swiglu: bool = True,
|
|
126
|
+
model: PreTrainedModel = None,
|
|
53
127
|
) -> None:
|
|
54
128
|
"""
|
|
55
129
|
Apply Liger kernels to replace original implementation in HuggingFace Mistral models
|
|
@@ -57,9 +131,19 @@ def apply_liger_kernel_to_mistral(
|
|
|
57
131
|
Args:
|
|
58
132
|
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
|
|
59
133
|
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
|
|
134
|
+
fused_linear_cross_entropy (bool):
|
|
135
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
136
|
+
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
137
|
+
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
138
|
+
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
60
139
|
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
61
140
|
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
|
|
141
|
+
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
142
|
+
loaded. Default is None.
|
|
62
143
|
"""
|
|
144
|
+
assert not (
|
|
145
|
+
cross_entropy and fused_linear_cross_entropy
|
|
146
|
+
), "cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
63
147
|
|
|
64
148
|
from transformers.models.mistral import modeling_mistral
|
|
65
149
|
|
|
@@ -69,62 +153,543 @@ def apply_liger_kernel_to_mistral(
|
|
|
69
153
|
modeling_mistral.MistralRMSNorm = LigerRMSNorm
|
|
70
154
|
if cross_entropy:
|
|
71
155
|
modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
156
|
+
if fused_linear_cross_entropy:
|
|
157
|
+
modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
|
|
72
158
|
if swiglu:
|
|
73
159
|
modeling_mistral.MistralMLP = LigerSwiGLUMLP
|
|
74
160
|
|
|
161
|
+
if model is not None:
|
|
162
|
+
# The model instance already exists, so we need to additionally patch the
|
|
163
|
+
# instance variables that reference already-instantiated modules
|
|
164
|
+
|
|
165
|
+
if hasattr(model, "model"):
|
|
166
|
+
# The case for MistralForCausalLM, MistralForTokenClassification for example
|
|
167
|
+
base_model = model.model
|
|
168
|
+
else:
|
|
169
|
+
# Direct MistralModel
|
|
170
|
+
base_model = model
|
|
171
|
+
|
|
172
|
+
if rms_norm:
|
|
173
|
+
_patch_rms_norm_module(base_model.norm)
|
|
174
|
+
|
|
175
|
+
for decoder_layer in base_model.layers:
|
|
176
|
+
if swiglu:
|
|
177
|
+
_bind_method_to_module(
|
|
178
|
+
decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
|
|
179
|
+
)
|
|
180
|
+
if rms_norm:
|
|
181
|
+
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
182
|
+
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
183
|
+
|
|
75
184
|
|
|
76
185
|
def apply_liger_kernel_to_mixtral(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Swap HuggingFace Mixtral components for their Liger kernel equivalents.

    ``transformers.models.mixtral.modeling_mixtral`` is patched in place, so models
    built afterwards use the Liger implementations. If ``model`` is given, the
    modules that already exist on that instance are patched as well.

    Args:
        rope (bool): Replace ``apply_rotary_pos_emb`` with Liger's RoPE. Default is True.
        cross_entropy (bool): Replace ``CrossEntropyLoss`` with Liger's cross entropy. Default is False.
        fused_linear_cross_entropy (bool):
            Replace ``MixtralForCausalLM.forward`` with a forward using Liger's fused
            linear cross entropy, which does not materialize the logits. Default is True.
            Cannot be combined with `cross_entropy`.
        rms_norm (bool): Replace ``MixtralRMSNorm`` with Liger's RMSNorm. Default is True.
        swiglu (bool): Replace ``MixtralBlockSparseTop2MLP`` (the MoE expert MLP) with
            Liger's SwiGLU version. Default is True.
        model (PreTrainedModel): An already-instantiated model whose existing modules
            should also be patched. Default is None.

    Raises:
        AssertionError: if both `cross_entropy` and `fused_linear_cross_entropy` are
            True (the check is an ``assert`` and is skipped under ``python -O``).
    """
    assert not (
        cross_entropy and fused_linear_cross_entropy
    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

    from transformers.models.mixtral import modeling_mixtral

    # Module-level swaps: only affect objects created after this point.
    if rope:
        modeling_mixtral.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        modeling_mixtral.MixtralRMSNorm = LigerRMSNorm
    if cross_entropy:
        modeling_mixtral.CrossEntropyLoss = LigerCrossEntropyLoss
    if fused_linear_cross_entropy:
        modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
    if swiglu:
        modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP

    if model is None:
        return

    # Instance-level patches for modules that were built before the swaps above.
    # Task heads (e.g. MixtralForCausalLM) keep the decoder under `.model`;
    # a bare MixtralModel is itself the decoder.
    decoder = getattr(model, "model", model)

    if rms_norm:
        _patch_rms_norm_module(decoder.norm)

    for layer in decoder.layers:
        if swiglu:
            for expert_mlp in layer.block_sparse_moe.experts:
                _bind_method_to_module(
                    expert_mlp, "forward", LigerBlockSparseTop2MLP.forward
                )
        if rms_norm:
            _patch_rms_norm_module(layer.input_layernorm)
            _patch_rms_norm_module(layer.post_attention_layernorm)
|
|
249
|
+
|
|
103
250
|
|
|
104
251
|
def apply_liger_kernel_to_gemma(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    geglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Swap HuggingFace Gemma (1 and 1.1) components for Liger kernel equivalents.

    For Gemma2 use `apply_liger_kernel_to_gemma2` instead. The
    ``transformers.models.gemma.modeling_gemma`` module is patched in place, so
    models built afterwards use the Liger implementations. If ``model`` is given,
    the modules that already exist on that instance are patched too.

    Args:
        rope (bool): Replace ``apply_rotary_pos_emb`` with Liger's RoPE. Default is True.
        cross_entropy (bool): Replace ``CrossEntropyLoss`` with Liger's cross entropy. Default is False.
        fused_linear_cross_entropy (bool):
            Replace ``GemmaForCausalLM.forward`` with a forward using Liger's fused linear
            cross entropy, which does not materialize the logits. Default is True.
            Cannot be combined with `cross_entropy`.
        rms_norm (bool): Replace ``GemmaRMSNorm`` with a Gemma-configured Liger RMSNorm
            (offset=1.0, casting_mode="gemma"). Default is True.
        geglu (bool): Replace ``GemmaMLP`` with Liger's GeGLU MLP. Default is True.
        model (PreTrainedModel): An already-instantiated model whose existing modules
            should also be patched. Default is None.
    """
    assert not (
        cross_entropy and fused_linear_cross_entropy
    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

    from transformers.models.gemma import modeling_gemma

    # Gemma needs a differently configured RMSNorm than the default: offset=1.0,
    # "gemma" casting, and zero-initialized weights for newly built modules.
    # Reference implementation:
    # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
    gemma_rms_norm_cls = partial(
        LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
    )
    patch_gemma_norm = partial(
        _patch_rms_norm_module, casting_mode="gemma", offset=1.0
    )

    # Module-level swaps: only affect objects created after this point.
    if rope:
        modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        modeling_gemma.GemmaRMSNorm = gemma_rms_norm_cls
    if cross_entropy:
        modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
    if geglu:
        modeling_gemma.GemmaMLP = LigerGEGLUMLP
    if fused_linear_cross_entropy:
        modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward

    if model is None:
        return

    # Instance-level patches. Task heads keep the decoder under `.model`;
    # a bare GemmaModel is itself the decoder.
    decoder = getattr(model, "model", model)

    if rms_norm:
        patch_gemma_norm(decoder.norm)

    for layer in decoder.layers:
        if geglu:
            _bind_method_to_module(layer.mlp, "forward", LigerGEGLUMLP.forward)
        if rms_norm:
            patch_gemma_norm(layer.input_layernorm)
            patch_gemma_norm(layer.post_attention_layernorm)
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
def apply_liger_kernel_to_gemma2(
    rope: bool = True,
    cross_entropy: bool = True,
    rms_norm: bool = True,
    geglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Swap HuggingFace Gemma2 components for Liger kernel equivalents.

    For Gemma 1 / 1.1 use `apply_liger_kernel_to_gemma` instead. This patcher has no
    fused-linear-cross-entropy option. ``transformers.models.gemma2.modeling_gemma2``
    is patched in place. If ``model`` is given, the modules that already exist on
    that instance are patched too.

    Args:
        rope (bool): Replace ``apply_rotary_pos_emb`` with Liger's RoPE. Default is True.
        cross_entropy (bool): Replace ``CrossEntropyLoss`` with Liger's cross entropy. Default is True.
        rms_norm (bool): Replace ``Gemma2RMSNorm`` with a Gemma-configured Liger RMSNorm
            (offset=1.0, casting_mode="gemma"). Default is True.
        geglu (bool): Replace ``Gemma2MLP`` with Liger's GeGLU MLP. Default is True.
        model (PreTrainedModel): An already-instantiated model whose existing modules
            should also be patched. Default is None.
    """
    from transformers.models.gemma2 import modeling_gemma2

    # Same RMSNorm configuration as the Gemma 1 patcher: offset=1.0,
    # "gemma" casting, zero-initialized weights for newly built modules.
    gemma2_rms_norm_cls = partial(
        LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros"
    )
    patch_gemma2_norm = partial(
        _patch_rms_norm_module, offset=1.0, casting_mode="gemma"
    )

    # Module-level swaps: only affect objects created after this point.
    if rope:
        modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        # Gemma (v1) reference RMSNorm implementation:
        # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
        modeling_gemma2.Gemma2RMSNorm = gemma2_rms_norm_cls
    if cross_entropy:
        modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
    if geglu:
        modeling_gemma2.Gemma2MLP = LigerGEGLUMLP

    if model is None:
        return

    # Instance-level patches. Task heads keep the decoder under `.model`;
    # a bare Gemma2Model is itself the decoder.
    decoder = getattr(model, "model", model)

    if rms_norm:
        patch_gemma2_norm(decoder.norm)

    for layer in decoder.layers:
        if geglu:
            _bind_method_to_module(layer.mlp, "forward", LigerGEGLUMLP.forward)
        if rms_norm:
            # Gemma2 decoder layers carry four norms (pre/post attention and feed-forward).
            for norm_name in (
                "input_layernorm",
                "post_attention_layernorm",
                "pre_feedforward_layernorm",
                "post_feedforward_layernorm",
            ):
                patch_gemma2_norm(getattr(layer, norm_name))
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
def apply_liger_kernel_to_qwen2(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Swap HuggingFace Qwen2 components for Liger kernel equivalents.

    ``transformers.models.qwen2.modeling_qwen2`` is patched in place, so models built
    afterwards use the Liger implementations. If ``model`` is given, the modules that
    already exist on that instance are patched too.

    Args:
        rope (bool): Replace ``apply_rotary_pos_emb`` with Liger's RoPE. Default is True.
        cross_entropy (bool): Replace ``CrossEntropyLoss`` with Liger's cross entropy. Default is False.
        fused_linear_cross_entropy (bool):
            Replace ``Qwen2ForCausalLM.forward`` with a forward using Liger's fused linear
            cross entropy, which does not materialize the logits. Default is True.
            Cannot be combined with `cross_entropy`.
        rms_norm (bool): Replace ``Qwen2RMSNorm`` with Liger's RMSNorm. Default is True.
        swiglu (bool): Replace ``Qwen2MLP`` with Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): An already-instantiated model whose existing modules
            should also be patched. Default is None.
    """
    assert not (
        cross_entropy and fused_linear_cross_entropy
    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

    from transformers.models.qwen2 import modeling_qwen2

    # Module-level swaps: only affect objects created after this point.
    if rope:
        modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
    if cross_entropy:
        modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
    if fused_linear_cross_entropy:
        modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
    if swiglu:
        modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP

    if model is None:
        return

    # Instance-level patches. Task heads (Qwen2ForCausalLM, ...) keep the decoder
    # under `.model`; a bare Qwen2Model is itself the decoder.
    decoder = getattr(model, "model", model)

    if rms_norm:
        _patch_rms_norm_module(decoder.norm)

    for layer in decoder.layers:
        if swiglu:
            _bind_method_to_module(layer.mlp, "forward", LigerSwiGLUMLP.forward)
        if rms_norm:
            _patch_rms_norm_module(layer.input_layernorm)
            _patch_rms_norm_module(layer.post_attention_layernorm)
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
def apply_liger_kernel_to_qwen2_vl(
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    layer_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Swap HuggingFace Qwen2-VL components for Liger kernel equivalents.

    NOTE: Qwen2-VL is not available in transformers<=4.44.2. Rotary embeddings are
    not patched (Qwen2-VL's multimodal RoPE is not supported yet).
    ``transformers.models.qwen2_vl.modeling_qwen2_vl`` is patched in place. If
    ``model`` is given, the modules that already exist on that instance (language
    decoder and, when present, vision blocks) are patched too.

    Args:
        cross_entropy (bool): Replace ``CrossEntropyLoss`` with Liger's cross entropy. Default is False.
        fused_linear_cross_entropy (bool):
            Replace ``Qwen2VLForConditionalGeneration.forward`` with a forward using
            Liger's fused linear cross entropy, which does not materialize the logits.
            Default is True. Cannot be combined with `cross_entropy`.
        rms_norm (bool): Replace ``Qwen2RMSNorm`` with Liger's RMSNorm. Default is True.
        layer_norm (bool): Replace ``LayerNorm`` with Liger's LayerNorm. Default is True.
        swiglu (bool): Replace ``Qwen2MLP`` with Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): An already-instantiated model whose existing modules
            should also be patched. Default is None.
    """
    assert not (
        cross_entropy and fused_linear_cross_entropy
    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

    from transformers.models.qwen2_vl import modeling_qwen2_vl

    from liger_kernel.transformers.model.qwen2_vl import (
        lce_forward as qwen2_vl_lce_forward,
    )

    # TODO: Support Qwen2-VL's multimodal RoPE implementation

    # Module-level swaps: only affect objects created after this point.
    if rms_norm:
        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
        modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
    if layer_norm:
        modeling_qwen2_vl.LayerNorm = LigerLayerNorm
    if cross_entropy:
        modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
    if fused_linear_cross_entropy:
        modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
    if swiglu:
        modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP

    if model is None:
        return

    # Instance-level patches. Qwen2VLForConditionalGeneration keeps the language
    # decoder under `.model`; a bare Qwen2VLModel is itself the decoder.
    language_model = getattr(model, "model", model)

    if hasattr(model, "visual"):
        # Vision tower (Qwen2VisionTransformerPretrainedModel): patch both norms of each block.
        for block in model.visual.blocks:
            if not layer_norm:
                continue
            for norm_name in ("norm1", "norm2"):
                _patch_layer_norm_module(getattr(block, norm_name))

    if rms_norm:
        _patch_rms_norm_module(language_model.norm)

    for layer in language_model.layers:
        if swiglu:
            _bind_method_to_module(layer.mlp, "forward", LigerSwiGLUMLP.forward)
        if rms_norm:
            _patch_rms_norm_module(layer.input_layernorm)
            _patch_rms_norm_module(layer.post_attention_layernorm)
|
|
534
|
+
|
|
535
|
+
|
|
536
|
+
def apply_liger_kernel_to_phi3(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Swap HuggingFace Phi3 components for Liger kernel equivalents.

    ``transformers.models.phi3.modeling_phi3`` is patched in place, so models built
    afterwards use the Liger implementations. If ``model`` is given, the modules
    that already exist on that instance are patched too.

    Args:
        rope (bool): Replace ``apply_rotary_pos_emb`` with Liger's RoPE. Default is True.
        cross_entropy (bool): Replace ``CrossEntropyLoss`` with Liger's cross entropy. Default is False.
        fused_linear_cross_entropy (bool):
            Replace ``Phi3ForCausalLM.forward`` with a forward using Liger's fused linear
            cross entropy, which does not materialize the logits. Default is True.
            Cannot be combined with `cross_entropy`.
        rms_norm (bool): Replace ``Phi3RMSNorm`` with Liger's RMSNorm. Default is True.
        swiglu (bool): Replace ``Phi3MLP`` with ``LigerPhi3SwiGLUMLP``. Default is True.
        model (PreTrainedModel): An already-instantiated model whose existing modules
            should also be patched. Default is None.
    """
    assert not (
        cross_entropy and fused_linear_cross_entropy
    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

    from transformers.models.phi3 import modeling_phi3

    # Module-level swaps. RoPE and RMSNorm reuse the generic Liger versions
    # (RMSNorm with the default Llama-style settings); the MLP needs the
    # Phi3-specific Liger class.
    if rope:
        modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        modeling_phi3.Phi3RMSNorm = LigerRMSNorm
    if swiglu:
        modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
    if cross_entropy:
        modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
    if fused_linear_cross_entropy:
        modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward

    if model is None:
        return

    # Instance-level patches. Task heads (Phi3ForCausalLM, ...) keep the decoder
    # under `.model`; a bare Phi3Model is itself the decoder.
    decoder = getattr(model, "model", model)

    if rms_norm:
        _patch_rms_norm_module(decoder.norm)

    for layer in decoder.layers:
        if swiglu:
            _bind_method_to_module(layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward)
        if rms_norm:
            _patch_rms_norm_module(layer.input_layernorm)
            _patch_rms_norm_module(layer.post_attention_layernorm)
|
|
598
|
+
|
|
599
|
+
|
|
600
|
+
# Dispatch table: HuggingFace `config.model_type` string (the keys used in
# transformers/models/auto/modeling_auto.py) -> the apply_liger_kernel_to_* patcher.
MODEL_TYPE_TO_APPLY_LIGER_FN = {
    # Gemma family (v1 and v2 need different patchers)
    "gemma": apply_liger_kernel_to_gemma,
    "gemma2": apply_liger_kernel_to_gemma2,
    # Llama / Mistral / Mixtral
    "llama": apply_liger_kernel_to_llama,
    "mistral": apply_liger_kernel_to_mistral,
    "mixtral": apply_liger_kernel_to_mixtral,
    # Qwen2 text and vision-language
    "qwen2": apply_liger_kernel_to_qwen2,
    "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
    # Phi3
    "phi3": apply_liger_kernel_to_phi3,
}
|
|
611
|
+
|
|
612
|
+
|
|
613
|
+
def _apply_liger_kernel(model_type: str, **kwargs) -> None:
    """
    Applies Liger kernels based on the specified model type. The custom
    kernels for the specified model type will be applied with the provided
    keyword arguments, otherwise the default configuration will be used.

    ** Note: Calling _apply_liger_kernel() after model initialization
    will not be able to fully patch models. This must be called before model initialization.
    If the model has already been instantiated, use `_apply_liger_kernel_to_instance`
    instead, which additionally patches the modules that already exist on the instance.

    Args:
        - model_type: the model types as defined in transformers/models/auto/modeling_auto.py
          and specified in the model's config.json
        - kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function.
          Keys the apply function does not accept are silently dropped.
    """
    if not model_type:
        logger.info("Model type was not provided. No Liger kernels will be applied.")
        return

    if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN:
        logger.info(
            f"There are currently no Liger kernels supported for model type: {model_type}."
        )
        return

    apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
    apply_fn_signature = inspect.signature(apply_fn)

    # Filter out the keyword arguments that are not supported by the apply function
    applicable_kwargs = {
        key: value
        for key, value in kwargs.items()
        if key in apply_fn_signature.parameters
    }

    logger.info(
        f"Applying Liger kernels for model type: {model_type} with kwargs: {applicable_kwargs}"
    )

    # Assume this is invoked pre-model initialization, so we only need to patch transformers code
    apply_fn(**applicable_kwargs)
|
|
654
|
+
|
|
655
|
+
|
|
656
|
+
def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
    """
    Applies Liger kernels to the provided model instance.

    The model type is read from ``model.config.model_type`` and dispatched through
    ``MODEL_TYPE_TO_APPLY_LIGER_FN``. The selected apply function patches the
    transformers module code and also the modules already instantiated on ``model``.

    Args:
        - model: the model instance to apply Liger kernels to
        - kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function.
          Keys the apply function does not accept are ignored, and so is a ``model`` key:
          the ``model`` argument above is always the instance that gets patched.
    """
    # Missing `config` or missing `config.model_type` both resolve to None.
    model_type = getattr(getattr(model, "config", None), "model_type", None)

    if not model_type:
        logger.info(
            "Model type could not be determined from model config. No Liger kernels will be applied."
        )
        return

    if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN:
        logger.info(
            f"There are currently no Liger kernels supported for model type: {model_type}."
        )
        return

    apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]

    apply_fn_signature = inspect.signature(apply_fn)

    # Keep only the keyword arguments the apply function accepts. Every apply
    # function has a `model` parameter, and `model` is passed explicitly below.
    # Forwarding a `model` entry from kwargs as well would raise
    # "TypeError: got multiple values for keyword argument 'model'", so it is dropped.
    applicable_kwargs = {
        key: value
        for key, value in kwargs.items()
        if key != "model" and key in apply_fn_signature.parameters
    }

    logger.info(
        f"Applying Liger kernels to model instance with model type: {model_type} with kwargs: {applicable_kwargs}"
    )

    apply_fn(model=model, **applicable_kwargs)
|