liger-kernel 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/env_report.py +2 -0
- liger_kernel/ops/cross_entropy.py +143 -30
- liger_kernel/ops/fused_linear_cross_entropy.py +19 -2
- liger_kernel/ops/group_norm.py +322 -0
- liger_kernel/ops/rms_norm.py +27 -6
- liger_kernel/transformers/cross_entropy.py +44 -12
- liger_kernel/transformers/functional.py +34 -1
- liger_kernel/transformers/fused_linear_cross_entropy.py +31 -4
- liger_kernel/transformers/group_norm.py +56 -0
- liger_kernel/transformers/model/gemma2.py +277 -0
- liger_kernel/transformers/monkey_patch.py +101 -62
- liger_kernel/transformers/rms_norm.py +11 -3
- {liger_kernel-0.4.0.dist-info → liger_kernel-0.4.1.dist-info}/METADATA +5 -3
- {liger_kernel-0.4.0.dist-info → liger_kernel-0.4.1.dist-info}/RECORD +18 -15
- {liger_kernel-0.4.0.dist-info → liger_kernel-0.4.1.dist-info}/WHEEL +1 -1
- {liger_kernel-0.4.0.dist-info → liger_kernel-0.4.1.dist-info}/LICENSE +0 -0
- {liger_kernel-0.4.0.dist-info → liger_kernel-0.4.1.dist-info}/NOTICE +0 -0
- {liger_kernel-0.4.0.dist-info → liger_kernel-0.4.1.dist-info}/top_level.txt +0 -0
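
The headline change in this release is Gemma2 support: a new `liger_kernel/transformers/model/gemma2.py` (reproduced below) plus the wiring for it in `monkey_patch.py`. As orientation, the sketch below shows how the updated patch would typically be applied; it is a minimal illustration, not taken from the package, and it assumes a CUDA environment with a recent `transformers` release. The checkpoint id is the placeholder used in the docstrings further down.

```python
from transformers import AutoModelForCausalLM

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2

# 0.4.1 defaults: rope=True, cross_entropy=False, fused_linear_cross_entropy=True,
# rms_norm=True, geglu=True, so training goes through the fused Liger loss and
# never materializes the full [tokens, vocab] logits tensor.
apply_liger_kernel_to_gemma2()

model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b")
```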
liger_kernel/transformers/model/gemma2.py (new file):

@@ -0,0 +1,277 @@
+import logging
+from typing import Optional, Tuple, Union
+
+import torch
+from torch.nn import CrossEntropyLoss
+from transformers.cache_utils import HybridCache
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.models.gemma2.modeling_gemma2 import (
+    _CONFIG_FOR_DOC,
+    GEMMA2_INPUTS_DOCSTRING,
+)
+from transformers.utils import (
+    add_start_docstrings_to_model_forward,
+    replace_return_docstrings,
+)
+
+from liger_kernel.transformers.fused_linear_cross_entropy import (
+    LigerFusedLinearCrossEntropyLoss,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def lce_forward_deprecated(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[HybridCache] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, GemmaForCausalLM
+    >>> model = GemmaForCausalLM.from_pretrained("google/gemma-2-9b")
+    >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
+    >>> prompt = "What is your favorite condiment?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "What is your favorite condiment?"
+    ```"""
+
+    if self.training and self.config._attn_implementation != "eager":
+        logger.warning_once(
+            "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
+            f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
+        )
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = (
+        return_dict if return_dict is not None else self.config.use_return_dict
+    )
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+    )
+
+    hidden_states = outputs[0]
+
+    loss = None
+    logits = None
+
+    if self.training and (labels is not None):
+        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+
+        # flatten
+
+        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+        shift_labels = shift_labels.view(-1)
+
+        lce = LigerFusedLinearCrossEntropyLoss(
+            softcap=self.config.final_logit_softcapping
+        )
+        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+
+    else:
+        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+        logits = self.lm_head(hidden_states)
+        if self.config.final_logit_softcapping is not None:
+            logits = logits / self.config.final_logit_softcapping
+            logits = torch.tanh(logits)
+            logits = logits * self.config.final_logit_softcapping
+
+        loss = None
+        if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
+
+
+@add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
+@replace_return_docstrings(
+    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+)
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[HybridCache] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    num_logits_to_keep: int = 0,
+    **loss_kwargs,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        num_logits_to_keep (`int`, *optional*):
+            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, GemmaForCausalLM
+
+    >>> model = GemmaForCausalLM.from_pretrained("google/gemma-2-9b")
+    >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
+
+    >>> prompt = "What is your favorite condiment?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "What is your favorite condiment?"
+    ```"""
+
+    if self.training and self.config._attn_implementation != "eager":
+        logger.warning_once(
+            "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
+            f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
+        )
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = (
+        return_dict if return_dict is not None else self.config.use_return_dict
+    )
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+    )
+
+    hidden_states = outputs[0]
+
+    logits = None
+    loss = None
+    # if in training mode, don't materialize logits
+    if self.training and (labels is not None):
+        # We do the same thing as ForCausalLMLoss but using Liger FLCE
+
+        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+
+        # flatten tokens
+        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+        shift_labels = shift_labels.view(-1)
+
+        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
+        lce = LigerFusedLinearCrossEntropyLoss(
+            softcap=self.config.final_logit_softcapping,
+            reduction=reduction,
+        )
+
+        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+        if reduction == "sum":
+            loss /= loss_kwargs["num_items_in_batch"]
+
+    else:  # if in inference mode materialize logits
+        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+        if self.config.final_logit_softcapping is not None:
+            logits = logits / self.config.final_logit_softcapping
+            logits = torch.tanh(logits)
+            logits = logits * self.config.final_logit_softcapping
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
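
For context on the training branch above: `LigerFusedLinearCrossEntropyLoss` takes the `lm_head` weight and the flattened hidden states and fuses the output projection with the cross entropy, so the `[tokens, vocab]` logits are never materialized. Below is a minimal sketch of that equivalence, with toy shapes chosen purely for illustration and a CUDA device assumed (the Liger kernels are Triton-based).

```python
import torch
import torch.nn.functional as F

from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

num_tokens, hidden_size, vocab_size = 8, 32, 128
weight = torch.randn(vocab_size, hidden_size, device="cuda")
hidden = torch.randn(num_tokens, hidden_size, device="cuda")
labels = torch.randint(0, vocab_size, (num_tokens,), device="cuda")

# Reference path: materialize the logits, then take the cross entropy.
reference = F.cross_entropy(hidden @ weight.T, labels)

# Fused path: projection and loss computed together, no logits tensor allocated.
fused = LigerFusedLinearCrossEntropyLoss()(weight, hidden, labels)

print(reference.item(), fused.item())  # should agree up to numerical tolerance
```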
liger_kernel/transformers/monkey_patch.py:

@@ -8,12 +8,17 @@ from packaging import version
 from transformers import PreTrainedModel

 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
+from liger_kernel.transformers.functional import liger_cross_entropy
 from liger_kernel.transformers.geglu import LigerGEGLUMLP
 from liger_kernel.transformers.layer_norm import LigerLayerNorm
 from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
 from liger_kernel.transformers.model.gemma import (
     lce_forward_deprecated as gemma_lce_forward_deprecated,
 )
+from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
+from liger_kernel.transformers.model.gemma2 import (
+    lce_forward_deprecated as gemma2_lce_forward_deprected,
+)
 from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
 from liger_kernel.transformers.model.llama import (
     lce_forward_deprecated as llama_lce_forward_deprecated,
@@ -99,6 +104,7 @@ def apply_liger_kernel_to_llama(
     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

     from transformers.models.llama import modeling_llama
+    from transformers.models.llama.modeling_llama import LlamaModel

     if rope:
         modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
@@ -106,8 +112,16 @@ def apply_liger_kernel_to_llama(
         modeling_llama.LlamaRMSNorm = LigerRMSNorm
     if swiglu:
         modeling_llama.LlamaMLP = LigerSwiGLUMLP
+
     if cross_entropy:
-        modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
+
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
             modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
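
The new `cross_entropy` branch above patches at the function level rather than swapping the `CrossEntropyLoss` class: on transformers releases at or above `SUPPORTED_TRANSFORMER_VERSION`, the causal-LM loss is computed by helpers in `transformers.loss.loss_utils` that call `nn.functional.cross_entropy`, so rebinding that attribute appears to be enough to route the loss through the Liger kernel. A stripped-down sketch of just that mechanism (version gating and the deprecation warning omitted):

```python
from liger_kernel.transformers.functional import liger_cross_entropy
from transformers.loss.loss_utils import nn  # the torch.nn module as imported by loss_utils

# After this rebinding, the loss helpers in loss_utils (and anything else that
# goes through torch.nn.functional.cross_entropy) dispatch to the Liger kernel.
nn.functional.cross_entropy = liger_cross_entropy
```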
@@ -119,15 +133,8 @@ def apply_liger_kernel_to_llama(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)

-        if hasattr(model, "model"):
-            base_model = model.model
-        elif hasattr(model, "transformer"):
-            # LlamaForQuestionAnswering uses "transformer" instead of "model"
-            base_model = model.transformer
-        else:
-            # Direct LlamaModel
-            base_model = model
+        # get the base model from the model instance
+        base_model: LlamaModel = getattr(model, model.base_model_prefix, model)

         if rms_norm:
             _patch_rms_norm_module(base_model.norm)
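
The repeated `hasattr` chains for locating the base model are collapsed into a single lookup keyed on `base_model_prefix`, the attribute every `PreTrainedModel` subclass defines ("model" for `LlamaForCausalLM`, "transformer" for `LlamaForQuestionAnswering`, per the removed comments). A hedged sketch of the lookup in isolation; the helper name is ours, not part of the package:

```python
from transformers import PreTrainedModel


def get_base_model(model: PreTrainedModel):
    # Resolve the decoder stack via base_model_prefix; fall back to the module
    # itself when it has no attribute of that name (the bare LlamaModel case).
    return getattr(model, model.base_model_prefix, model)
```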
@@ -194,7 +201,13 @@ def apply_liger_kernel_to_mllama(
     if swiglu:
         modeling_mllama.MllamaTextMLP = LigerSwiGLUMLP
     if cross_entropy:
-        modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
             modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
@@ -258,7 +271,7 @@ def apply_liger_kernel_to_mistral(
     Apply Liger kernels to replace original implementation in HuggingFace Mistral models

     Args:
-        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
         fused_linear_cross_entropy (bool):
             Whether to apply Liger's fused linear cross entropy loss. Default is True.
@@ -275,6 +288,7 @@ def apply_liger_kernel_to_mistral(
     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

     from transformers.models.mistral import modeling_mistral
+    from transformers.models.mistral.modeling_mistral import MistralModel

     if rope:
         modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
@@ -291,12 +305,8 @@ def apply_liger_kernel_to_mistral(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules

-        if hasattr(model, "model"):
-            base_model = model.model
-        else:
-            # Direct MistralModel
-            base_model = model
+        # get the base model from the model instance
+        base_model: MistralModel = getattr(model, model.base_model_prefix, model)

         if rms_norm:
             _patch_rms_norm_module(base_model.norm)
@@ -340,13 +350,21 @@ def apply_liger_kernel_to_mixtral(
     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

     from transformers.models.mixtral import modeling_mixtral
+    from transformers.models.mixtral.modeling_mixtral import MixtralModel

     if rope:
         modeling_mixtral.apply_rotary_pos_emb = liger_rotary_pos_emb
     if rms_norm:
         modeling_mixtral.MixtralRMSNorm = LigerRMSNorm
     if cross_entropy:
-        modeling_mixtral.CrossEntropyLoss = LigerCrossEntropyLoss
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_mixtral.CrossEntropyLoss = LigerCrossEntropyLoss
+
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
             modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
@@ -360,12 +378,8 @@ def apply_liger_kernel_to_mixtral(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules

-        if hasattr(model, "model"):
-            base_model = model.model
-        else:
-            # Direct MixtralModel
-            base_model = model
+        # get the base model from the model instance
+        base_model: MixtralModel = getattr(model, model.base_model_prefix, model)

         if rms_norm:
             _patch_rms_norm_module(base_model.norm)
@@ -410,6 +424,7 @@ def apply_liger_kernel_to_gemma(
     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

     from transformers.models.gemma import modeling_gemma
+    from transformers.models.gemma.modeling_gemma import GemmaModel

     # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
     LigerRMSNormForGemma = partial(
@@ -424,7 +439,13 @@ def apply_liger_kernel_to_gemma(
     if rms_norm:
         modeling_gemma.GemmaRMSNorm = LigerRMSNormForGemma
     if cross_entropy:
-        modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
     if geglu:
         modeling_gemma.GemmaMLP = LigerGEGLUMLP
     if fused_linear_cross_entropy:
@@ -438,12 +459,8 @@ def apply_liger_kernel_to_gemma(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules

-        if hasattr(model, "model"):
-            base_model = model.model
-        else:
-            # Direct GemmaModel
-            base_model = model
+        # get the base model from the model instance
+        base_model: GemmaModel = getattr(model, model.base_model_prefix, model)

         if rms_norm:
             _patch_rms_norm_module_for_gemma(base_model.norm)
@@ -460,7 +477,8 @@ def apply_liger_kernel_to_gemma(

 def apply_liger_kernel_to_gemma2(
     rope: bool = True,
-    cross_entropy: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
     rms_norm: bool = True,
     geglu: bool = True,
     model: PreTrainedModel = None,
@@ -471,16 +489,25 @@ def apply_liger_kernel_to_gemma2(

     Args:
         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
-        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
         geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
         loaded. Default is None.
     """
+    assert not (
+        cross_entropy and fused_linear_cross_entropy
+    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+
     from transformers.models.gemma2 import modeling_gemma2
+    from transformers.models.gemma2.modeling_gemma2 import Gemma2Model

     LigerRMSNormForGemma2 = partial(
-        LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros"
+        LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False
     )
     _patch_rms_norm_module_for_gemma2 = partial(
         _patch_rms_norm_module, offset=1.0, casting_mode="gemma"
@@ -492,7 +519,19 @@ def apply_liger_kernel_to_gemma2(
         # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
         modeling_gemma2.Gemma2RMSNorm = LigerRMSNormForGemma2
     if cross_entropy:
-        modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
+    if fused_linear_cross_entropy:
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
     if geglu:
         modeling_gemma2.Gemma2MLP = LigerGEGLUMLP

@@ -500,12 +539,8 @@ def apply_liger_kernel_to_gemma2(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules

-        if hasattr(model, "model"):
-            base_model = model.model
-        else:
-            # Direct Gemma2Model
-            base_model = model
+        # get the base model from the model instance
+        base_model: Gemma2Model = getattr(model, model.base_model_prefix, model)

         if rms_norm:
             _patch_rms_norm_module_for_gemma2(base_model.norm)
@@ -556,13 +591,21 @@ def apply_liger_kernel_to_qwen2(
     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

     from transformers.models.qwen2 import modeling_qwen2
+    from transformers.models.qwen2.modeling_qwen2 import Qwen2Model

     if rope:
         modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
     if rms_norm:
         modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
+
     if cross_entropy:
-        modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss

     # import pdb; pdb.set_trace()
     if fused_linear_cross_entropy:
@@ -580,12 +623,8 @@ def apply_liger_kernel_to_qwen2(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules

-        if hasattr(model, "model"):
-            base_model = model.model
-        else:
-            # Direct Qwen2Model
-            base_model = model
+        # get the base model from the model instance
+        base_model: Qwen2Model = getattr(model, model.base_model_prefix, model)

         if rms_norm:
             _patch_rms_norm_module(base_model.norm)
@@ -630,6 +669,7 @@ def apply_liger_kernel_to_qwen2_vl(
     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

     from transformers.models.qwen2_vl import modeling_qwen2_vl
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel

     from liger_kernel.transformers.model.qwen2_vl import (
         lce_forward as qwen2_vl_lce_forward,
@@ -653,12 +693,8 @@ def apply_liger_kernel_to_qwen2_vl(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules

-        if hasattr(model, "model"):
-            base_model = model.model
-        else:
-            # Direct Qwen2VLModel
-            base_model = model
+        # get the base model from the model instance
+        base_model: Qwen2VLModel = getattr(model, model.base_model_prefix, model)

         if hasattr(model, "visual"):
             # Patch Qwen2VisionTransformerPretrainedModel
@@ -707,6 +743,7 @@ def apply_liger_kernel_to_phi3(
     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

     from transformers.models.phi3 import modeling_phi3
+    from transformers.models.phi3.modeling_phi3 import Phi3Model

     if rope:
         modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb  # Same as Gemma
@@ -715,7 +752,13 @@ def apply_liger_kernel_to_phi3(
     if swiglu:
         modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
     if cross_entropy:
-        modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
             modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
@@ -727,12 +770,8 @@ def apply_liger_kernel_to_phi3(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules

-        if hasattr(model, "model"):
-            base_model = model.model
-        else:
-            # Direct Phi3Model
-            base_model = model
+        # get the base model from the model instance
+        base_model: Phi3Model = getattr(model, model.base_model_prefix, model)

         if rms_norm:
             _patch_rms_norm_module(base_model.norm)
liger_kernel/transformers/rms_norm.py:

@@ -6,7 +6,13 @@ from liger_kernel.ops.rms_norm import LigerRMSNormFunction

 class LigerRMSNorm(nn.Module):
     def __init__(
-        self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones"
+        self,
+        hidden_size,
+        eps=1e-6,
+        offset=0.0,
+        casting_mode="llama",
+        init_fn="ones",
+        in_place=True,
     ):
         super().__init__()
         assert init_fn in [
@@ -16,10 +22,11 @@ class LigerRMSNorm(nn.Module):
         self.weight = nn.Parameter(
             torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)
         )
-        self.variance_epsilon, self.offset, self.casting_mode = (
+        self.variance_epsilon, self.offset, self.casting_mode, self.in_place = (
             eps,
             offset,
             casting_mode,
+            in_place,
         )

     def forward(self, hidden_states):
@@ -29,7 +36,8 @@ class LigerRMSNorm(nn.Module):
             self.variance_epsilon,
             self.offset,
             self.casting_mode,
+            self.in_place,
         )

     def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}"
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}"
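
`LigerRMSNorm` now exposes an `in_place` switch, and the Gemma2 patch above constructs its norm with `in_place=False`; what the flag changes inside the kernel lives in the `ops/rms_norm.py` portion of this release, which is not reproduced here. A small usage sketch, assuming a CUDA device since the kernel is Triton-based:

```python
import torch

from liger_kernel.transformers.rms_norm import LigerRMSNorm

# Mirrors the Gemma2 configuration from monkey_patch.py: offset=1.0,
# casting_mode="gemma", init_fn="zeros", in_place=False.
norm = LigerRMSNorm(
    hidden_size=64, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False
).cuda()
print(norm)  # extra_repr now reports the in_place flag as well

x = torch.randn(2, 8, 64, device="cuda")
y = norm(x)
```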