liger-kernel 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/env_report.py +46 -0
- liger_kernel/ops/cross_entropy.py +5 -5
- liger_kernel/ops/fused_linear_cross_entropy.py +50 -21
- liger_kernel/ops/geglu.py +6 -1
- liger_kernel/ops/rms_norm.py +142 -20
- liger_kernel/ops/rope.py +3 -3
- liger_kernel/transformers/__init__.py +6 -0
- liger_kernel/transformers/auto_model.py +33 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +2 -2
- liger_kernel/transformers/geglu.py +4 -2
- liger_kernel/transformers/model/gemma.py +138 -0
- liger_kernel/transformers/model/llama.py +1 -1
- liger_kernel/transformers/model/mistral.py +138 -0
- liger_kernel/transformers/model/phi3.py +136 -0
- liger_kernel/transformers/model/qwen2.py +135 -0
- liger_kernel/transformers/monkey_patch.py +203 -10
- liger_kernel/transformers/rms_norm.py +20 -4
- liger_kernel/transformers/swiglu.py +24 -0
- liger_kernel/transformers/trainer_integration.py +2 -45
- liger_kernel-0.2.0.dist-info/METADATA +307 -0
- liger_kernel-0.2.0.dist-info/RECORD +33 -0
- liger_kernel-0.1.0.dist-info/METADATA +0 -16
- liger_kernel-0.1.0.dist-info/RECORD +0 -27
- {liger_kernel-0.1.0.dist-info → liger_kernel-0.2.0.dist-info}/LICENSE +0 -0
- {liger_kernel-0.1.0.dist-info → liger_kernel-0.2.0.dist-info}/NOTICE +0 -0
- {liger_kernel-0.1.0.dist-info → liger_kernel-0.2.0.dist-info}/WHEEL +0 -0
- {liger_kernel-0.1.0.dist-info → liger_kernel-0.2.0.dist-info}/top_level.txt +0 -0
|
@@ -1,9 +1,23 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
import logging
|
|
3
|
+
from functools import partial
|
|
4
|
+
|
|
1
5
|
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
|
2
6
|
from liger_kernel.transformers.geglu import LigerGEGLUMLP
|
|
3
|
-
from liger_kernel.transformers.model.
|
|
7
|
+
from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
|
|
8
|
+
from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
|
|
9
|
+
from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
|
|
10
|
+
from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
|
|
11
|
+
from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
|
|
4
12
|
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
|
5
13
|
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
|
6
|
-
from liger_kernel.transformers.swiglu import
|
|
14
|
+
from liger_kernel.transformers.swiglu import (
|
|
15
|
+
LigerBlockSparseTop2MLP,
|
|
16
|
+
LigerPhi3SwiGLUMLP,
|
|
17
|
+
LigerSwiGLUMLP,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
logger = logging.getLogger(__name__)
|
|
7
21
|
|
|
8
22
|
|
|
9
23
|
def apply_liger_kernel_to_llama(
|
|
@@ -20,7 +34,7 @@ def apply_liger_kernel_to_llama(
|
|
|
20
34
|
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
|
|
21
35
|
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
22
36
|
fused_linear_cross_entropy (bool):
|
|
23
|
-
Whether to apply Liger's fused
|
|
37
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
24
38
|
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
25
39
|
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
26
40
|
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
@@ -42,12 +56,13 @@ def apply_liger_kernel_to_llama(
|
|
|
42
56
|
if cross_entropy:
|
|
43
57
|
modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
44
58
|
if fused_linear_cross_entropy:
|
|
45
|
-
modeling_llama.LlamaForCausalLM.forward =
|
|
59
|
+
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
|
|
46
60
|
|
|
47
61
|
|
|
48
62
|
def apply_liger_kernel_to_mistral(
|
|
49
63
|
rope: bool = True,
|
|
50
|
-
cross_entropy: bool =
|
|
64
|
+
cross_entropy: bool = False,
|
|
65
|
+
fused_linear_cross_entropy: bool = True,
|
|
51
66
|
rms_norm: bool = True,
|
|
52
67
|
swiglu: bool = True,
|
|
53
68
|
) -> None:
|
|
@@ -57,9 +72,17 @@ def apply_liger_kernel_to_mistral(
|
|
|
57
72
|
Args:
|
|
58
73
|
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
|
|
59
74
|
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
|
|
75
|
+
fused_linear_cross_entropy (bool):
|
|
76
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
77
|
+
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
78
|
+
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
79
|
+
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
60
80
|
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
61
81
|
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
|
|
62
82
|
"""
|
|
83
|
+
assert not (
|
|
84
|
+
cross_entropy and fused_linear_cross_entropy
|
|
85
|
+
), "cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
63
86
|
|
|
64
87
|
from transformers.models.mistral import modeling_mistral
|
|
65
88
|
|
|
@@ -69,6 +92,8 @@ def apply_liger_kernel_to_mistral(
|
|
|
69
92
|
modeling_mistral.MistralRMSNorm = LigerRMSNorm
|
|
70
93
|
if cross_entropy:
|
|
71
94
|
modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
95
|
+
if fused_linear_cross_entropy:
|
|
96
|
+
modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
|
|
72
97
|
if swiglu:
|
|
73
98
|
modeling_mistral.MistralMLP = LigerSwiGLUMLP
|
|
74
99
|
|
|
@@ -94,7 +119,7 @@ def apply_liger_kernel_to_mixtral(
|
|
|
94
119
|
if rope:
|
|
95
120
|
modeling_mixtral.apply_rotary_pos_emb = liger_rotary_pos_emb
|
|
96
121
|
if rms_norm:
|
|
97
|
-
modeling_mixtral.
|
|
122
|
+
modeling_mixtral.MixtralRMSNorm = LigerRMSNorm
|
|
98
123
|
if cross_entropy:
|
|
99
124
|
modeling_mixtral.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
100
125
|
if swiglu:
|
|
@@ -104,12 +129,13 @@ def apply_liger_kernel_to_mixtral(
|
|
|
104
129
|
def apply_liger_kernel_to_gemma(
|
|
105
130
|
rope: bool = True,
|
|
106
131
|
cross_entropy: bool = True,
|
|
132
|
+
fused_linear_cross_entropy: bool = True,
|
|
107
133
|
rms_norm: bool = True,
|
|
108
134
|
geglu: bool = True,
|
|
109
135
|
) -> None:
|
|
110
136
|
"""
|
|
111
|
-
Apply Liger kernels to replace original implementation in HuggingFace
|
|
112
|
-
to make GPU go burrr.
|
|
137
|
+
Apply Liger kernels to replace original implementation in HuggingFace Gemma
|
|
138
|
+
(Gemma 1 and 1.1 supported, for Gemma2 please use `apply_liger_kernel_to_gemma2` ) to make GPU go burrr.
|
|
113
139
|
|
|
114
140
|
Args:
|
|
115
141
|
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
|
|
@@ -117,14 +143,181 @@ def apply_liger_kernel_to_gemma(
|
|
|
117
143
|
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
118
144
|
geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
|
|
119
145
|
"""
|
|
120
|
-
|
|
146
|
+
assert not (
|
|
147
|
+
cross_entropy and fused_linear_cross_entropy
|
|
148
|
+
), "cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
149
|
+
|
|
121
150
|
from transformers.models.gemma import modeling_gemma
|
|
122
151
|
|
|
123
152
|
if rope:
|
|
124
153
|
modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
|
|
125
154
|
if rms_norm:
|
|
126
|
-
modeling_gemma.
|
|
155
|
+
# https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
|
|
156
|
+
modeling_gemma.GemmaRMSNorm = partial(
|
|
157
|
+
LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
|
|
158
|
+
)
|
|
127
159
|
if cross_entropy:
|
|
128
160
|
modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
129
161
|
if geglu:
|
|
130
162
|
modeling_gemma.GemmaMLP = LigerGEGLUMLP
|
|
163
|
+
if fused_linear_cross_entropy:
|
|
164
|
+
modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def apply_liger_kernel_to_gemma2(
|
|
168
|
+
rope: bool = True,
|
|
169
|
+
cross_entropy: bool = True,
|
|
170
|
+
rms_norm: bool = True,
|
|
171
|
+
geglu: bool = True,
|
|
172
|
+
) -> None:
|
|
173
|
+
"""
|
|
174
|
+
Apply Liger kernels to replace original implementation in HuggingFace Gemma2
|
|
175
|
+
(for Gemma1 please use `apply_liger_kernel_to_gemma`) to make GPU go burrr.
|
|
176
|
+
|
|
177
|
+
Args:
|
|
178
|
+
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
|
|
179
|
+
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
|
|
180
|
+
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
181
|
+
geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
|
|
182
|
+
"""
|
|
183
|
+
from transformers.models.gemma2 import modeling_gemma2
|
|
184
|
+
|
|
185
|
+
if rope:
|
|
186
|
+
modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
|
|
187
|
+
if rms_norm:
|
|
188
|
+
# https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
|
|
189
|
+
modeling_gemma2.Gemma2RMSNorm = partial(
|
|
190
|
+
LigerRMSNorm, offset=1.0, init_fn="zeros"
|
|
191
|
+
)
|
|
192
|
+
if cross_entropy:
|
|
193
|
+
modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
194
|
+
if geglu:
|
|
195
|
+
modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def apply_liger_kernel_to_qwen2(
|
|
199
|
+
rope: bool = True,
|
|
200
|
+
cross_entropy: bool = False,
|
|
201
|
+
fused_linear_cross_entropy: bool = True,
|
|
202
|
+
rms_norm: bool = True,
|
|
203
|
+
swiglu: bool = True,
|
|
204
|
+
) -> None:
|
|
205
|
+
"""
|
|
206
|
+
Apply Liger kernels to replace original implementation in HuggingFace Qwen2 models
|
|
207
|
+
|
|
208
|
+
Args:
|
|
209
|
+
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
|
|
210
|
+
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
211
|
+
fused_linear_cross_entropy (bool):
|
|
212
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
213
|
+
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
214
|
+
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
215
|
+
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
216
|
+
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
|
|
217
|
+
"""
|
|
218
|
+
assert not (
|
|
219
|
+
cross_entropy and fused_linear_cross_entropy
|
|
220
|
+
), "cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
221
|
+
|
|
222
|
+
from transformers.models.qwen2 import modeling_qwen2
|
|
223
|
+
|
|
224
|
+
if rope:
|
|
225
|
+
modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
|
|
226
|
+
if rms_norm:
|
|
227
|
+
modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
|
|
228
|
+
if cross_entropy:
|
|
229
|
+
modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
230
|
+
if fused_linear_cross_entropy:
|
|
231
|
+
modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
|
|
232
|
+
if swiglu:
|
|
233
|
+
modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def apply_liger_kernel_to_phi3(
|
|
237
|
+
rope: bool = True,
|
|
238
|
+
cross_entropy: bool = False,
|
|
239
|
+
fused_linear_cross_entropy: bool = True,
|
|
240
|
+
rms_norm: bool = True,
|
|
241
|
+
swiglu: bool = True,
|
|
242
|
+
) -> None:
|
|
243
|
+
"""
|
|
244
|
+
Apply Liger kernels to replace original implementation in HuggingFace Phi3 models.
|
|
245
|
+
|
|
246
|
+
Args:
|
|
247
|
+
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
|
|
248
|
+
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
249
|
+
fused_linear_cross_entropy (bool):
|
|
250
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
251
|
+
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
252
|
+
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
253
|
+
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
254
|
+
swiglu (bool): Whether to apply Liger's SwiGLU Phi3MLP. Default is True.
|
|
255
|
+
"""
|
|
256
|
+
assert not (
|
|
257
|
+
cross_entropy and fused_linear_cross_entropy
|
|
258
|
+
), "cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
259
|
+
|
|
260
|
+
from transformers.models.phi3 import modeling_phi3
|
|
261
|
+
|
|
262
|
+
if rope:
|
|
263
|
+
modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb # Same as Gemma
|
|
264
|
+
if rms_norm:
|
|
265
|
+
modeling_phi3.Phi3RMSNorm = LigerRMSNorm # Same as Llama
|
|
266
|
+
if swiglu:
|
|
267
|
+
modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
|
|
268
|
+
if cross_entropy:
|
|
269
|
+
modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
270
|
+
if fused_linear_cross_entropy:
|
|
271
|
+
modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
# Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
|
|
275
|
+
MODEL_TYPE_TO_APPLY_LIGER_FN = {
|
|
276
|
+
"gemma": apply_liger_kernel_to_gemma,
|
|
277
|
+
"gemma2": apply_liger_kernel_to_gemma2,
|
|
278
|
+
"llama": apply_liger_kernel_to_llama,
|
|
279
|
+
"mistral": apply_liger_kernel_to_mistral,
|
|
280
|
+
"mixtral": apply_liger_kernel_to_mixtral,
|
|
281
|
+
"qwen2": apply_liger_kernel_to_qwen2,
|
|
282
|
+
"phi3": apply_liger_kernel_to_phi3,
|
|
283
|
+
}
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
def _apply_liger_kernel(model_type: str = "", **kwargs) -> None:
|
|
287
|
+
"""
|
|
288
|
+
Applies Liger kernels based on the specified model type. The custom
|
|
289
|
+
kernels for the specified model type will be applied with the provided
|
|
290
|
+
keyword arguments, otherwise the default configuration will be used.
|
|
291
|
+
|
|
292
|
+
Args:
|
|
293
|
+
- model_type: the model types as defined in transformers/models/auto/modeling_auto.py
|
|
294
|
+
and specified in the model's config.json
|
|
295
|
+
- kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function.
|
|
296
|
+
"""
|
|
297
|
+
|
|
298
|
+
if not model_type:
|
|
299
|
+
logger.info("Model type was not provided. No Liger kernels will be applied.")
|
|
300
|
+
return
|
|
301
|
+
|
|
302
|
+
if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys():
|
|
303
|
+
logger.info(
|
|
304
|
+
f"There are currently no Liger kernels supported for model type: {model_type}."
|
|
305
|
+
)
|
|
306
|
+
return
|
|
307
|
+
|
|
308
|
+
apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
|
|
309
|
+
apply_fn_signature = inspect.signature(apply_fn)
|
|
310
|
+
|
|
311
|
+
# Filter out the keyword arguments that are not supported by the apply function
|
|
312
|
+
applicable_kwargs = {
|
|
313
|
+
key: value
|
|
314
|
+
for key, value in kwargs.items()
|
|
315
|
+
if key in apply_fn_signature.parameters
|
|
316
|
+
}
|
|
317
|
+
|
|
318
|
+
logger.info(
|
|
319
|
+
f"Applying Liger kernels for model type: {model_type} with kwargs: {applicable_kwargs}"
|
|
320
|
+
)
|
|
321
|
+
|
|
322
|
+
# Apply the default combination of liger kernels available for the model
|
|
323
|
+
apply_fn(**applicable_kwargs)
|
|
@@ -5,12 +5,28 @@ from liger_kernel.ops.rms_norm import LigerRMSNormFunction
|
|
|
5
5
|
|
|
6
6
|
|
|
7
7
|
class LigerRMSNorm(nn.Module):
|
|
8
|
-
def __init__(
|
|
8
|
+
def __init__(
|
|
9
|
+
self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones"
|
|
10
|
+
):
|
|
9
11
|
super().__init__()
|
|
10
|
-
|
|
11
|
-
|
|
12
|
+
assert init_fn in [
|
|
13
|
+
"ones",
|
|
14
|
+
"zeros",
|
|
15
|
+
], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
|
|
16
|
+
self.weight = nn.Parameter(
|
|
17
|
+
torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)
|
|
18
|
+
)
|
|
19
|
+
self.variance_epsilon, self.offset, self.casting_mode = (
|
|
20
|
+
eps,
|
|
21
|
+
offset,
|
|
22
|
+
casting_mode,
|
|
23
|
+
)
|
|
12
24
|
|
|
13
25
|
def forward(self, hidden_states):
|
|
14
26
|
return LigerRMSNormFunction.apply(
|
|
15
|
-
hidden_states,
|
|
27
|
+
hidden_states,
|
|
28
|
+
self.weight,
|
|
29
|
+
self.variance_epsilon,
|
|
30
|
+
self.offset,
|
|
31
|
+
self.casting_mode,
|
|
16
32
|
)
|
|
@@ -38,3 +38,27 @@ class LigerBlockSparseTop2MLP(nn.Module):
|
|
|
38
38
|
def forward(self, x):
|
|
39
39
|
|
|
40
40
|
return self.w2(LigerSiLUMulFunction.apply(self.w1(x), self.w3(x)))
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class LigerPhi3SwiGLUMLP(nn.Module):
|
|
44
|
+
"""
|
|
45
|
+
Patch Phi3MLP to use LigerSiLUMulFunction
|
|
46
|
+
https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/models/phi3/modeling_phi3.py#L241
|
|
47
|
+
"""
|
|
48
|
+
|
|
49
|
+
def __init__(self, config):
|
|
50
|
+
super().__init__()
|
|
51
|
+
self.config = config
|
|
52
|
+
self.hidden_size = config.hidden_size
|
|
53
|
+
self.intermediate_size = config.intermediate_size
|
|
54
|
+
self.gate_up_proj = nn.Linear(
|
|
55
|
+
self.hidden_size, 2 * self.intermediate_size, bias=False
|
|
56
|
+
)
|
|
57
|
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
|
58
|
+
if config.hidden_act not in ["silu", "swish"]:
|
|
59
|
+
raise ValueError(f"Activation function {config.hidden_act} not supported.")
|
|
60
|
+
|
|
61
|
+
def forward(self, x):
|
|
62
|
+
up_states = self.gate_up_proj(x)
|
|
63
|
+
gate, up_states = up_states.chunk(2, dim=-1)
|
|
64
|
+
return self.down_proj(LigerSiLUMulFunction.apply(gate, up_states))
|
|
@@ -1,45 +1,2 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
from liger_kernel.transformers.monkey_patch import (
|
|
4
|
-
apply_liger_kernel_to_gemma,
|
|
5
|
-
apply_liger_kernel_to_llama,
|
|
6
|
-
apply_liger_kernel_to_mistral,
|
|
7
|
-
apply_liger_kernel_to_mixtral,
|
|
8
|
-
)
|
|
9
|
-
|
|
10
|
-
logger = logging.getLogger(__name__)
|
|
11
|
-
|
|
12
|
-
# Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
|
|
13
|
-
MODEL_TYPE_TO_APPLY_LIGER_FN = {
|
|
14
|
-
"gemma": apply_liger_kernel_to_gemma,
|
|
15
|
-
"llama": apply_liger_kernel_to_llama,
|
|
16
|
-
"mistral": apply_liger_kernel_to_mistral,
|
|
17
|
-
"mixtral": apply_liger_kernel_to_mixtral,
|
|
18
|
-
}
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
def _apply_liger_kernel(model_type: str = "", **kwargs) -> None:
|
|
22
|
-
"""
|
|
23
|
-
Applies Liger kernels based on the specified model type. The custom
|
|
24
|
-
kernels for the specified model type will be applied with the provided
|
|
25
|
-
keyword arguments, otherwise the default configuration will be used.
|
|
26
|
-
|
|
27
|
-
Args:
|
|
28
|
-
- model_type: the model types as defined in transformers/models/auto/modeling_auto.py
|
|
29
|
-
and specified in the model's config.json
|
|
30
|
-
- kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function.
|
|
31
|
-
"""
|
|
32
|
-
|
|
33
|
-
if not model_type:
|
|
34
|
-
logger.info("Model type was not provided. No Liger kernels will be applied.")
|
|
35
|
-
return
|
|
36
|
-
|
|
37
|
-
if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys():
|
|
38
|
-
logger.info(
|
|
39
|
-
f"There are currently no Liger kernels supported for model type: {model_type}."
|
|
40
|
-
)
|
|
41
|
-
return
|
|
42
|
-
|
|
43
|
-
logger.info(f"Applying Liger kernels for model type: {model_type}.")
|
|
44
|
-
# Apply the default combination of liger kernels available for the model
|
|
45
|
-
MODEL_TYPE_TO_APPLY_LIGER_FN[model_type](**kwargs)
|
|
1
|
+
# To not break HF Trainer integration
|
|
2
|
+
from liger_kernel.transformers.monkey_patch import _apply_liger_kernel # noqa: F401
|