liger-kernel 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported public registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.
- liger_kernel/chunked_loss/README.md +25 -0
- liger_kernel/chunked_loss/__init__.py +3 -0
- liger_kernel/chunked_loss/cpo_loss.py +18 -8
- liger_kernel/chunked_loss/dpo_loss.py +20 -10
- liger_kernel/chunked_loss/functional.py +4 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +58 -44
- liger_kernel/chunked_loss/fused_linear_preference.py +108 -60
- liger_kernel/chunked_loss/fused_linear_rlhf.py +213 -0
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +246 -0
- liger_kernel/chunked_loss/grpo_loss.py +160 -0
- liger_kernel/chunked_loss/jsd_loss.py +154 -0
- liger_kernel/chunked_loss/kto_loss.py +172 -0
- liger_kernel/chunked_loss/orpo_loss.py +8 -9
- liger_kernel/chunked_loss/simpo_loss.py +22 -8
- liger_kernel/env_report.py +5 -12
- liger_kernel/ops/cross_entropy.py +102 -51
- liger_kernel/ops/experimental/embedding.py +1 -3
- liger_kernel/ops/experimental/mm_int8int2.py +3 -9
- liger_kernel/ops/fused_linear_cross_entropy.py +89 -55
- liger_kernel/ops/fused_linear_jsd.py +14 -32
- liger_kernel/ops/geglu.py +6 -17
- liger_kernel/ops/group_norm.py +11 -28
- liger_kernel/ops/jsd.py +5 -9
- liger_kernel/ops/kl_div.py +8 -11
- liger_kernel/ops/layer_norm.py +23 -12
- liger_kernel/ops/qwen2vl_mrope.py +8 -25
- liger_kernel/ops/rms_norm.py +14 -32
- liger_kernel/ops/rope.py +31 -33
- liger_kernel/ops/swiglu.py +4 -8
- liger_kernel/ops/tvd.py +207 -0
- liger_kernel/ops/utils.py +3 -2
- liger_kernel/transformers/__init__.py +19 -24
- liger_kernel/transformers/auto_model.py +6 -13
- liger_kernel/transformers/cross_entropy.py +7 -9
- liger_kernel/transformers/experimental/embedding.py +1 -3
- liger_kernel/transformers/functional.py +28 -7
- liger_kernel/transformers/fused_linear_cross_entropy.py +15 -10
- liger_kernel/transformers/geglu.py +1 -4
- liger_kernel/transformers/group_norm.py +9 -15
- liger_kernel/transformers/jsd.py +1 -3
- liger_kernel/transformers/kl_div.py +1 -3
- liger_kernel/transformers/layer_norm.py +3 -9
- liger_kernel/transformers/model/gemma.py +18 -40
- liger_kernel/transformers/model/gemma2.py +19 -41
- liger_kernel/transformers/model/llama.py +22 -48
- liger_kernel/transformers/model/mistral.py +14 -26
- liger_kernel/transformers/model/mixtral.py +24 -54
- liger_kernel/transformers/model/mllama.py +16 -36
- liger_kernel/transformers/model/olmo2.py +124 -0
- liger_kernel/transformers/model/phi3.py +18 -40
- liger_kernel/transformers/model/qwen2.py +18 -40
- liger_kernel/transformers/model/qwen2_vl.py +36 -32
- liger_kernel/transformers/monkey_patch.py +214 -144
- liger_kernel/transformers/rms_norm.py +4 -4
- liger_kernel/transformers/rope.py +2 -2
- liger_kernel/transformers/swiglu.py +2 -8
- liger_kernel/transformers/trainer/__init__.py +1 -3
- liger_kernel/transformers/trainer/orpo_trainer.py +31 -18
- liger_kernel/transformers/tvd.py +13 -0
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -3
- liger_kernel/utils.py +49 -0
- {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/METADATA +53 -26
- liger_kernel-0.5.4.dist-info/RECORD +74 -0
- {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/WHEEL +1 -1
- liger_kernel-0.5.2.dist-info/RECORD +0 -65
- {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/LICENSE +0 -0
- {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/NOTICE +0 -0
- {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/monkey_patch.py

@@ -1,9 +1,11 @@
 import inspect
 import logging
+
 from functools import partial
 from typing import Callable

 import transformers
+
 from packaging import version
 from transformers import PreTrainedModel

@@ -12,38 +14,24 @@ from liger_kernel.transformers.functional import liger_cross_entropy
 from liger_kernel.transformers.geglu import LigerGEGLUMLP
 from liger_kernel.transformers.layer_norm import LigerLayerNorm
 from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
-from liger_kernel.transformers.model.gemma import (
-    lce_forward_deprecated as gemma_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
 from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
-from liger_kernel.transformers.model.gemma2 import (
-    lce_forward_deprecated as gemma2_lce_forward_deprected,
-)
+from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
 from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
-from liger_kernel.transformers.model.llama import (
-    lce_forward_deprecated as llama_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
 from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
 from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
-from liger_kernel.transformers.model.mixtral import (
-    lce_forward_deprecated as mixtral_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated
 from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
-from liger_kernel.transformers.model.phi3 import (
-    lce_forward_deprecated as phi3_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.phi3 import lce_forward_deprecated as phi3_lce_forward_deprecated
 from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
-from liger_kernel.transformers.model.qwen2 import (
-    lce_forward_deprecated as qwen2_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.qwen2 import lce_forward_deprecated as qwen2_lce_forward_deprecated
 from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
-from liger_kernel.transformers.swiglu import (
-    LigerBlockSparseTop2MLP,
-    LigerPhi3SwiGLUMLP,
-    LigerSwiGLUMLP,
-)
+from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
+from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
+from liger_kernel.transformers.swiglu import LigerSwiGLUMLP

 transformer_version = version.parse(transformers.__version__)

@@ -57,28 +45,101 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable):
     module.__dict__[method_name] = new_method.__get__(module, module.__class__)


-def _patch_rms_norm_module(
-    module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True
-):
+def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True):
     module.offset = offset
     module.casting_mode = casting_mode
-    module.variance_epsilon = (
-        getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
-    )
+    module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
     module.in_place = in_place
     _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
     _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)


 def _patch_layer_norm_module(module, eps=1e-6):
-    module.variance_epsilon = (
-        getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
-    )
+    module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
     module.hidden_size = module.normalized_shape
     _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
     _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)


+def apply_liger_kernel_to_granite(
+    rope: bool = True,
+    cross_entropy: bool = True,
+    fused_linear_cross_entropy: bool = False,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Granite 3 models
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is False.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+
+
+
+    Debugging notes:
+        If LigerSwiGLUMLP is OK for Llama, it should be fine for Granite, but it's not.
+    """
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.granite import modeling_granite
+    from transformers.models.granite.modeling_granite import GraniteModel
+
+    if swiglu:
+        modeling_granite.GraniteMLP = LigerSwiGLUMLP
+
+    if rms_norm:
+        modeling_granite.GraniteRMSNorm = LigerRMSNorm
+
+    if rope:
+        modeling_granite.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if cross_entropy:
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_granite.CrossEntropyLoss = LigerCrossEntropyLoss
+
+    if fused_linear_cross_entropy:
+        raise NotImplementedError("LigerFusedLinearCrossEntropy is not available for Granite models.")
+        # NOTE: Granite model `GraniteForCausalLM.forward` scales logits each
+        # call, so we can't sidestep logit materialization. A bit more work
+        # would be needed to add a scaling term to the `LigerFusedLinearCrossEntropyFunction`
+        # for the logit output.
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules (e.g. GraniteRMSNorm or GraniteMLP)
+
+        # get the base model from the model instance
+        base_model: GraniteModel = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
 def apply_liger_kernel_to_llama(
     rope: bool = True,
     cross_entropy: bool = False,
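For context, a usage sketch of the new Granite entry point (not part of the diff; the checkpoint name is a placeholder, and a CUDA device plus a transformers release that ships Granite support are assumed):

```python
import torch
from transformers import AutoModelForCausalLM

# Defined in liger_kernel/transformers/monkey_patch.py (see the hunk above).
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite

# Patch the transformers Granite classes before the model is instantiated.
# fused_linear_cross_entropy stays at its default (False): the hunk above raises
# NotImplementedError for it because GraniteForCausalLM scales logits on every call.
apply_liger_kernel_to_granite(rope=True, rms_norm=True, swiglu=True, cross_entropy=True)

model = AutoModelForCausalLM.from_pretrained(
    "ibm-granite/granite-3.1-2b-instruct",  # placeholder checkpoint name
    torch_dtype=torch.bfloat16,
).to("cuda")
```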
@@ -103,9 +164,9 @@ def apply_liger_kernel_to_llama(
         loaded. Default is None.
     """

-    assert not (
-        cross_entropy and fused_linear_cross_entropy
-    )
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )

     from transformers.models.llama import modeling_llama
     from transformers.models.llama.modeling_llama import LlamaModel
@@ -145,9 +206,7 @@ def apply_liger_kernel_to_llama(

         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(
-                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-                )
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -179,22 +238,18 @@ def apply_liger_kernel_to_mllama(
         loaded. Default is None.
     """

-    assert not (
-        cross_entropy and fused_linear_cross_entropy
-    )
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )

     from transformers.models.mllama import modeling_mllama
-    from transformers.models.mllama.modeling_mllama import (
-        MllamaForCausalLM,
-        MllamaForConditionalGeneration,
-        MllamaTextModel,
-        MllamaVisionModel,
-    )
+    from transformers.models.mllama.modeling_mllama import MllamaForCausalLM
+    from transformers.models.mllama.modeling_mllama import MllamaForConditionalGeneration
+    from transformers.models.mllama.modeling_mllama import MllamaTextModel
+    from transformers.models.mllama.modeling_mllama import MllamaVisionModel

     from liger_kernel.transformers.model.mllama import lce_forward as mllama_lce_forward
-    from liger_kernel.transformers.model.mllama import (
-        lce_forward_deprecated as mllama_lce_forward_deprecated,
-    )
+    from liger_kernel.transformers.model.mllama import lce_forward_deprecated as mllama_lce_forward_deprecated

     if rope:
         modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
@@ -241,9 +296,7 @@ def apply_liger_kernel_to_mllama(
             _patch_rms_norm_module(text_model.norm)
         for decoder_layer in text_model.layers:
             if swiglu:
-                _bind_method_to_module(
-                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-                )
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -287,9 +340,9 @@ def apply_liger_kernel_to_mistral(
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
         loaded. Default is None.
     """
-    assert not (
-        cross_entropy and fused_linear_cross_entropy
-    )
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )

     from transformers.models.mistral import modeling_mistral
     from transformers.models.mistral.modeling_mistral import MistralModel
@@ -317,9 +370,7 @@ def apply_liger_kernel_to_mistral(

         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(
-                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-                )
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -349,9 +400,9 @@ def apply_liger_kernel_to_mixtral(
         loaded. Default is None.
     """

-    assert not (
-        cross_entropy and fused_linear_cross_entropy
-    )
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )

     from transformers.models.mixtral import modeling_mixtral
     from transformers.models.mixtral.modeling_mixtral import MixtralModel
@@ -391,9 +442,7 @@ def apply_liger_kernel_to_mixtral(
         for decoder_layer in base_model.layers:
             if swiglu:
                 for expert in decoder_layer.block_sparse_moe.experts:
-                    _bind_method_to_module(
-                        expert, "forward", LigerBlockSparseTop2MLP.forward
-                    )
+                    _bind_method_to_module(expert, "forward", LigerBlockSparseTop2MLP.forward)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -423,20 +472,16 @@ def apply_liger_kernel_to_gemma(
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
         loaded. Default is None.
     """
-    assert not (
-        cross_entropy and fused_linear_cross_entropy
-    )
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )

     from transformers.models.gemma import modeling_gemma
     from transformers.models.gemma.modeling_gemma import GemmaModel

     # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
-    LigerRMSNormForGemma = partial(
-        LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
-    )
-    _patch_rms_norm_module_for_gemma = partial(
-        _patch_rms_norm_module, casting_mode="gemma", offset=1.0
-    )
+    LigerRMSNormForGemma = partial(LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma")
+    _patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0)

     if rope:
         modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
@@ -471,9 +516,7 @@ def apply_liger_kernel_to_gemma(

         for decoder_layer in base_model.layers:
             if geglu:
-                _bind_method_to_module(
-                    decoder_layer.mlp, "forward", LigerGEGLUMLP.forward
-                )
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
             if rms_norm:
                 _patch_rms_norm_module_for_gemma(decoder_layer.input_layernorm)
                 _patch_rms_norm_module_for_gemma(decoder_layer.post_attention_layernorm)
@@ -503,16 +546,14 @@ def apply_liger_kernel_to_gemma2(
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
         loaded. Default is None.
     """
-    assert not (
-        cross_entropy and fused_linear_cross_entropy
-    )
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )

     from transformers.models.gemma2 import modeling_gemma2
     from transformers.models.gemma2.modeling_gemma2 import Gemma2Model

-    LigerRMSNormForGemma2 = partial(
-        LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False
-    )
+    LigerRMSNormForGemma2 = partial(LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False)
     _patch_rms_norm_module_for_gemma2 = partial(
         _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
     )
@@ -551,20 +592,12 @@ def apply_liger_kernel_to_gemma2(

         for decoder_layer in base_model.layers:
             if geglu:
-                _bind_method_to_module(
-                    decoder_layer.mlp, "forward", LigerGEGLUMLP.forward
-                )
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
             if rms_norm:
                 _patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm)
-                _patch_rms_norm_module_for_gemma2(
-                    decoder_layer.post_attention_layernorm
-                )
-                _patch_rms_norm_module_for_gemma2(
-                    decoder_layer.pre_feedforward_layernorm
-                )
-                _patch_rms_norm_module_for_gemma2(
-                    decoder_layer.post_feedforward_layernorm
-                )
+                _patch_rms_norm_module_for_gemma2(decoder_layer.post_attention_layernorm)
+                _patch_rms_norm_module_for_gemma2(decoder_layer.pre_feedforward_layernorm)
+                _patch_rms_norm_module_for_gemma2(decoder_layer.post_feedforward_layernorm)


 def apply_liger_kernel_to_qwen2(
@@ -590,9 +623,9 @@ def apply_liger_kernel_to_qwen2(
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
         loaded. Default is None.
     """
-    assert not (
-        cross_entropy and fused_linear_cross_entropy
-    )
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )

     from transformers.models.qwen2 import modeling_qwen2
     from transformers.models.qwen2.modeling_qwen2 import Qwen2Model
@@ -633,9 +666,7 @@ def apply_liger_kernel_to_qwen2(

         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(
-                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-                )
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -667,21 +698,17 @@ def apply_liger_kernel_to_qwen2_vl(
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
         loaded. Default is None.
     """
-    assert not (
-        cross_entropy and fused_linear_cross_entropy
-    )
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )

     from transformers.models.qwen2_vl import modeling_qwen2_vl
     from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel

-    from liger_kernel.transformers.model.qwen2_vl import (
-        lce_forward as qwen2_vl_lce_forward,
-    )
+    from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward

     if rope:
-        modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = (
-            liger_multimodal_rotary_pos_emb
-        )
+        modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
     if rms_norm:
         # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
         modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
@@ -712,9 +739,7 @@ def apply_liger_kernel_to_qwen2_vl(
             _patch_rms_norm_module(base_model.norm)
         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(
-                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-                )
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -743,9 +768,9 @@ def apply_liger_kernel_to_phi3(
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
         loaded. Default is None.
     """
-    assert not (
-        cross_entropy and fused_linear_cross_entropy
-    )
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )

     from transformers.models.phi3 import modeling_phi3
     from transformers.models.phi3.modeling_phi3 import Phi3Model
@@ -783,23 +808,86 @@ def apply_liger_kernel_to_phi3(

         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(
-                    decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward
-                )
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


+def apply_liger_kernel_to_olmo2(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace OLMO2 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU Olmo2MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.olmo2 import modeling_olmo2
+    from transformers.models.olmo2.modeling_olmo2 import Olmo2Model
+
+    from liger_kernel.transformers.model.olmo2 import lce_forward as olmo2_lce_forward
+
+    if rope:
+        modeling_olmo2.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        modeling_olmo2.Olmo2RMSNorm = partial(LigerRMSNorm, in_place=False)
+    if swiglu:
+        modeling_olmo2.Olmo2MLP = LigerSwiGLUMLP
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: Olmo2Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
+                _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
+
+
 # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
 MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "gemma": apply_liger_kernel_to_gemma,
     "gemma2": apply_liger_kernel_to_gemma2,
     "llama": apply_liger_kernel_to_llama,
+    "granite": apply_liger_kernel_to_granite,
     "mllama": apply_liger_kernel_to_mllama,
     "mllama_text_model": apply_liger_kernel_to_mllama,
     "mistral": apply_liger_kernel_to_mistral,
     "mixtral": apply_liger_kernel_to_mixtral,
+    "olmo2": apply_liger_kernel_to_olmo2,
     "qwen2": apply_liger_kernel_to_qwen2,
     "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
     "phi3": apply_liger_kernel_to_phi3,
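The registry keys mirror `config.model_type`, so the new `"granite"` and `"olmo2"` entries make those checkpoints eligible for automatic patching. A hedged sketch of the pre-initialization dispatch path shown in this file (kwargs are illustrative):

```python
from liger_kernel.transformers.monkey_patch import (
    MODEL_TYPE_TO_APPLY_LIGER_FN,
    _apply_liger_kernel,
)

# Both new model types now resolve to their apply_liger_kernel_to_* functions.
assert "granite" in MODEL_TYPE_TO_APPLY_LIGER_FN
assert "olmo2" in MODEL_TYPE_TO_APPLY_LIGER_FN

# Pre-initialization path: patches the transformers modules for that model type;
# kwargs not accepted by the target function are filtered out before the call.
_apply_liger_kernel("olmo2", rope=True, rms_norm=True, swiglu=True)
```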
@@ -826,24 +914,16 @@ def _apply_liger_kernel(model_type: str, **kwargs) -> None:
         return

     if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys():
-        logger.info(
-            f"There are currently no Liger kernels supported for model type: {model_type}."
-        )
+        logger.info(f"There are currently no Liger kernels supported for model type: {model_type}.")
         return

     apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
     apply_fn_signature = inspect.signature(apply_fn)

     # Filter out the keyword arguments that are not supported by the apply function
-    applicable_kwargs = {
-        key: value
-        for key, value in kwargs.items()
-        if key in apply_fn_signature.parameters
-    }
+    applicable_kwargs = {key: value for key, value in kwargs.items() if key in apply_fn_signature.parameters}

-    logger.info(
-        f"Applying Liger kernels for model type: {model_type} with kwargs: {applicable_kwargs}"
-    )
+    logger.info(f"Applying Liger kernels for model type: {model_type} with kwargs: {applicable_kwargs}")

     # Assume this is invoked pre-model initialization, so we only need to patch transformers code
     apply_fn(**applicable_kwargs)
@@ -857,20 +937,14 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
         - model: the model instance to apply Liger kernels to
         - kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function.
     """
-    model_type = getattr(model, "config", None) and getattr(
-        model.config, "model_type", None
-    )
+    model_type = getattr(model, "config", None) and getattr(model.config, "model_type", None)

     if not model_type:
-        logger.info(
-            "Model type could not be determined from model config. No Liger kernels will be applied."
-        )
+        logger.info("Model type could not be determined from model config. No Liger kernels will be applied.")
         return

     if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys():
-        logger.info(
-            f"There are currently no Liger kernels supported for model type: {model_type}."
-        )
+        logger.info(f"There are currently no Liger kernels supported for model type: {model_type}.")
         return

     apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
@@ -878,11 +952,7 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
     apply_fn_signature = inspect.signature(apply_fn)

     # Filter out the keyword arguments that are not supported by the apply function
-    applicable_kwargs = {
-        key: value
-        for key, value in kwargs.items()
-        if key in apply_fn_signature.parameters
-    }
+    applicable_kwargs = {key: value for key, value in kwargs.items() if key in apply_fn_signature.parameters}
     logger.info(
         f"Applying Liger kernels to model instance with model type: {model_type} with kwargs: {applicable_kwargs}"
     )
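For an already-loaded model, the instance path shown in the last two hunks performs the equivalent per-module patching. A sketch (the checkpoint name is a placeholder; any model type present in `MODEL_TYPE_TO_APPLY_LIGER_FN` works):

```python
from transformers import AutoModelForCausalLM

from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance

model = AutoModelForCausalLM.from_pretrained("allenai/OLMo-2-1124-7B")  # placeholder

# model.config.model_type ("olmo2") selects apply_liger_kernel_to_olmo2, unsupported
# kwargs are dropped, and already-instantiated norm/MLP submodules are patched in place.
_apply_liger_kernel_to_instance(model=model, rms_norm=True, swiglu=True)
```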
liger_kernel/transformers/rms_norm.py

@@ -19,9 +19,7 @@ class LigerRMSNorm(nn.Module):
             "ones",
             "zeros",
         ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
-        self.weight = nn.Parameter(
-            torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)
-        )
+        self.weight = nn.Parameter(torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size))
         self.variance_epsilon, self.offset, self.casting_mode, self.in_place = (
             eps,
             offset,
@@ -40,4 +38,6 @@ class LigerRMSNorm(nn.Module):
         )

     def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}"
+        return (
+            f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}"
+        )
liger_kernel/transformers/rope.py

@@ -8,8 +8,8 @@ def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     Args:
         q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim).
        k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim).
-        cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim).
-        sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim).
+        cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim) or (bsz, seq_len, head_dim).
+        sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim) or (bsz, seq_len, head_dim).
         position_ids (torch.Tensor, optional): The position ids tensor. Defaults to None.
         unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1.

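A minimal shape sketch for the widened docstring above (a CUDA device is assumed, since the rotary embedding runs as a Triton kernel; the values are random because only the shapes matter here):

```python
import torch

from liger_kernel.transformers.rope import liger_rotary_pos_emb

bsz, n_q_head, n_kv_head, seq_len, head_dim = 2, 8, 2, 16, 64
q = torch.randn(bsz, n_q_head, seq_len, head_dim, device="cuda")
k = torch.randn(bsz, n_kv_head, seq_len, head_dim, device="cuda")

# cos/sin may now be per-batch, (bsz, seq_len, head_dim), in addition to the
# broadcastable (1, seq_len, head_dim) form that was documented before.
cos = torch.randn(bsz, seq_len, head_dim, device="cuda")
sin = torch.randn(bsz, seq_len, head_dim, device="cuda")

q_rot, k_rot = liger_rotary_pos_emb(q, k, cos, sin)
```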
liger_kernel/transformers/swiglu.py

@@ -16,10 +16,7 @@ class LigerSwiGLUMLP(nn.Module):
             raise ValueError(f"Activation function {config.hidden_act} not supported.")

     def forward(self, x):
-
-        return self.down_proj(
-            LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x))
-        )
+        return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))


 class LigerBlockSparseTop2MLP(nn.Module):
@@ -36,7 +33,6 @@ class LigerBlockSparseTop2MLP(nn.Module):
             raise ValueError(f"Activation function {config.hidden_act} not supported.")

     def forward(self, x):
-
         return self.w2(LigerSiLUMulFunction.apply(self.w1(x), self.w3(x)))


@@ -51,9 +47,7 @@ class LigerPhi3SwiGLUMLP(nn.Module):
         self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
-        self.gate_up_proj = nn.Linear(
-            self.hidden_size, 2 * self.intermediate_size, bias=False
-        )
+        self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         if config.hidden_act not in ["silu", "swish"]:
             raise ValueError(f"Activation function {config.hidden_act} not supported.")
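For orientation, a drop-in sketch of the module these swiglu.py hunks reformat (LigerSwiGLUMLP mirrors the Llama-style gate/up/down MLP interface; CUDA is assumed because the fused SiLU-mul op runs in Triton):

```python
import torch
from transformers.models.llama.configuration_llama import LlamaConfig

from liger_kernel.transformers.swiglu import LigerSwiGLUMLP

config = LlamaConfig(hidden_size=256, intermediate_size=1024, hidden_act="silu")
mlp = LigerSwiGLUMLP(config).to("cuda", dtype=torch.bfloat16)

x = torch.randn(2, 8, config.hidden_size, device="cuda", dtype=torch.bfloat16)
out = mlp(x)  # down_proj(SiLU(gate_proj(x)) * up_proj(x)), with the SiLU-mul fused
```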
liger_kernel/transformers/trainer/__init__.py

@@ -1,6 +1,4 @@
 try:
-    from liger_kernel.transformers.trainer.orpo_trainer import (
-        LigerORPOTrainer,
-    )
+    from liger_kernel.transformers.trainer.orpo_trainer import LigerORPOTrainer  # noqa: F401
 except ImportError:
     raise ImportError("Please `pip install trl` to use LigerORPOTrainer")
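The guarded re-export keeps the public import path unchanged, and trl stays an optional dependency. Usage sketch (assumes `pip install trl` has been run):

```python
# Raises ImportError("Please `pip install trl` to use LigerORPOTrainer") when trl is missing.
from liger_kernel.transformers.trainer import LigerORPOTrainer
```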