liger-kernel 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,56 @@
+ import torch
+ import torch.nn as nn
+
+ from liger_kernel.ops.group_norm import LigerGroupNormFunction
+
+
+ class LigerGroupNorm(nn.Module):
+     def __init__(self, num_channels, num_groups, eps=1e-6, bias=False, init_fn="ones"):
+         """
+         A Group Normalization layer.
+         Args:
+             num_channels (int): Number of channels in the input tensor.
+             num_groups (int): Number of groups to divide the channels into.
+             eps (float, optional): A value added to the denominator for numerical stability. Default: 1e-6.
+             bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``False``.
+             init_fn (str, optional): Initialization function for the learnable parameters. Default: "ones".
+         """
+         super().__init__()
+         assert init_fn in [
+             "ones",
+             "zeros",
+         ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
+
+         assert (
+             num_channels % num_groups == 0
+         ), f"Number of channels {num_channels} must be divisible by num_groups {num_groups}"
+         self.num_channels = num_channels
+         self.num_groups = num_groups
+         self.eps = eps
+         self.weight = nn.Parameter(
+             torch.ones(num_channels) if init_fn == "ones" else torch.zeros(num_channels)
+         )
+         self.bias = nn.Parameter(
+             torch.randn(num_channels) if bias else torch.zeros(num_channels)
+         )
+         self.variance_epsilon = eps
+
+     def forward(self, hidden_states):
+         # hidden_states: (batch_size, num_channels, *)
+         assert (
+             hidden_states.dim() >= 3
+         ), f"Input must have at least 3 dimensions, got {hidden_states.dim()}"
+         assert (
+             hidden_states.size(1) == self.num_channels
+         ), f"Input tensor must have {self.num_channels} channels, got {hidden_states.size(1)}"
+         return LigerGroupNormFunction.apply(
+             hidden_states,
+             self.weight,
+             self.bias,
+             self.num_channels,
+             self.num_groups,
+             self.variance_epsilon,
+         )
+
+     def extra_repr(self):
+         return f"num_channels={self.num_channels}, num_groups={self.num_groups}, eps={self.eps}"
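The hunk above adds a `LigerGroupNorm` module that dispatches to the Triton-based `LigerGroupNormFunction`. A minimal usage sketch, assuming the module is importable as `liger_kernel.transformers.group_norm` in 0.4.2 (the diff does not show the file name) and that a CUDA device is available for the Triton kernel; shapes are illustrative only:

```python
import torch

# Assumed import path; not shown in the diff itself.
from liger_kernel.transformers.group_norm import LigerGroupNorm

# Batch of 2, 16 channels split into 4 groups, 8 positions per channel.
# forward() requires at least 3 dimensions: (batch_size, num_channels, *).
x = torch.randn(2, 16, 8, device="cuda", dtype=torch.float32)

norm = LigerGroupNorm(num_channels=16, num_groups=4, eps=1e-6).to("cuda")
y = norm(x)  # same shape as x; per-group normalization with learnable affine parameters
print(y.shape)  # torch.Size([2, 16, 8])
```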
@@ -0,0 +1,277 @@
+ import logging
+ from typing import Optional, Tuple, Union
+
+ import torch
+ from torch.nn import CrossEntropyLoss
+ from transformers.cache_utils import HybridCache
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+ from transformers.models.gemma2.modeling_gemma2 import (
+     _CONFIG_FOR_DOC,
+     GEMMA2_INPUTS_DOCSTRING,
+ )
+ from transformers.utils import (
+     add_start_docstrings_to_model_forward,
+     replace_return_docstrings,
+ )
+
+ from liger_kernel.transformers.fused_linear_cross_entropy import (
+     LigerFusedLinearCrossEntropyLoss,
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ def lce_forward_deprecated(
+     self,
+     input_ids: torch.LongTensor = None,
+     attention_mask: Optional[torch.Tensor] = None,
+     position_ids: Optional[torch.LongTensor] = None,
+     past_key_values: Optional[HybridCache] = None,
+     inputs_embeds: Optional[torch.FloatTensor] = None,
+     labels: Optional[torch.LongTensor] = None,
+     use_cache: Optional[bool] = None,
+     output_attentions: Optional[bool] = None,
+     output_hidden_states: Optional[bool] = None,
+     return_dict: Optional[bool] = None,
+     cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+     r"""
+     Args:
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+     Returns:
+
+     Example:
+
+     ```python
+     >>> from transformers import AutoTokenizer, GemmaForCausalLM
+     >>> model = GemmaForCausalLM.from_pretrained("google/gemma-2-9b")
+     >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
+     >>> prompt = "What is your favorite condiment?"
+     >>> inputs = tokenizer(prompt, return_tensors="pt")
+     >>> # Generate
+     >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+     "What is your favorite condiment?"
+     ```"""
+
+     if self.training and self.config._attn_implementation != "eager":
+         logger.warning_once(
+             "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
+             f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
+         )
+     output_attentions = (
+         output_attentions
+         if output_attentions is not None
+         else self.config.output_attentions
+     )
+     output_hidden_states = (
+         output_hidden_states
+         if output_hidden_states is not None
+         else self.config.output_hidden_states
+     )
+     return_dict = (
+         return_dict if return_dict is not None else self.config.use_return_dict
+     )
+     # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+     outputs = self.model(
+         input_ids=input_ids,
+         attention_mask=attention_mask,
+         position_ids=position_ids,
+         past_key_values=past_key_values,
+         inputs_embeds=inputs_embeds,
+         use_cache=use_cache,
+         output_attentions=output_attentions,
+         output_hidden_states=output_hidden_states,
+         return_dict=return_dict,
+         cache_position=cache_position,
+     )
+
+     hidden_states = outputs[0]
+
+     loss = None
+     logits = None
+
+     if self.training and (labels is not None):
+         shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+         shift_labels = labels[..., 1:].contiguous()
+
+         # flatten
+
+         shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+         shift_labels = shift_labels.view(-1)
+
+         lce = LigerFusedLinearCrossEntropyLoss(
+             softcap=self.config.final_logit_softcapping
+         )
+         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+
+     else:
+         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+         logits = self.lm_head(hidden_states)
+         if self.config.final_logit_softcapping is not None:
+             logits = logits / self.config.final_logit_softcapping
+             logits = torch.tanh(logits)
+             logits = logits * self.config.final_logit_softcapping
+
+         loss = None
+         if labels is not None:
+             # Upcast to float if we need to compute the loss to avoid potential precision issues
+             logits = logits.float()
+             # Shift so that tokens < n predict n
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             # Flatten the tokens
+             loss_fct = CrossEntropyLoss()
+             shift_logits = shift_logits.view(-1, self.config.vocab_size)
+             shift_labels = shift_labels.view(-1)
+             # Enable model parallelism
+             shift_labels = shift_labels.to(shift_logits.device)
+             loss = loss_fct(shift_logits, shift_labels)
+
+     if not return_dict:
+         output = (logits,) + outputs[1:]
+         return (loss,) + output if loss is not None else output
+
+     return CausalLMOutputWithPast(
+         loss=loss,
+         logits=logits,
+         past_key_values=outputs.past_key_values,
+         hidden_states=outputs.hidden_states,
+         attentions=outputs.attentions,
+     )
+
+
+ @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
+ @replace_return_docstrings(
+     output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+ )
+ def lce_forward(
+     self,
+     input_ids: torch.LongTensor = None,
+     attention_mask: Optional[torch.Tensor] = None,
+     position_ids: Optional[torch.LongTensor] = None,
+     past_key_values: Optional[HybridCache] = None,
+     inputs_embeds: Optional[torch.FloatTensor] = None,
+     labels: Optional[torch.LongTensor] = None,
+     use_cache: Optional[bool] = None,
+     output_attentions: Optional[bool] = None,
+     output_hidden_states: Optional[bool] = None,
+     return_dict: Optional[bool] = None,
+     cache_position: Optional[torch.LongTensor] = None,
+     num_logits_to_keep: int = 0,
+     **loss_kwargs,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+     r"""
+     Args:
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+         num_logits_to_keep (`int`, *optional*):
+             Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+             `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+             token can save memory, which becomes quite significant for long sequences or large vocabulary sizes.
+
+     Returns:
+
+     Example:
+
+     ```python
+     >>> from transformers import AutoTokenizer, GemmaForCausalLM
+
+     >>> model = GemmaForCausalLM.from_pretrained("google/gemma-2-9b")
+     >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
+
+     >>> prompt = "What is your favorite condiment?"
+     >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+     >>> # Generate
+     >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+     "What is your favorite condiment?"
+     ```"""
+
+     if self.training and self.config._attn_implementation != "eager":
+         logger.warning_once(
+             "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
+             f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
+         )
+     output_attentions = (
+         output_attentions
+         if output_attentions is not None
+         else self.config.output_attentions
+     )
+     output_hidden_states = (
+         output_hidden_states
+         if output_hidden_states is not None
+         else self.config.output_hidden_states
+     )
+     return_dict = (
+         return_dict if return_dict is not None else self.config.use_return_dict
+     )
+     # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+     outputs = self.model(
+         input_ids=input_ids,
+         attention_mask=attention_mask,
+         position_ids=position_ids,
+         past_key_values=past_key_values,
+         inputs_embeds=inputs_embeds,
+         use_cache=use_cache,
+         output_attentions=output_attentions,
+         output_hidden_states=output_hidden_states,
+         return_dict=return_dict,
+         cache_position=cache_position,
+     )
+
+     hidden_states = outputs[0]
+
+     logits = None
+     loss = None
+     # if in training mode, don't materialize logits
+     if self.training and (labels is not None):
+         # We do the same thing as ForCausalLMLoss but using Liger FLCE
+
+         shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+         shift_labels = labels[..., 1:].contiguous()
+
+         # flatten tokens
+         shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+         shift_labels = shift_labels.view(-1)
+
+         reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
+         lce = LigerFusedLinearCrossEntropyLoss(
+             softcap=self.config.final_logit_softcapping,
+             reduction=reduction,
+         )
+
+         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+         if reduction == "sum":
+             loss /= loss_kwargs["num_items_in_batch"]
+
+     else:  # if in inference mode, materialize logits
+         logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+         if self.config.final_logit_softcapping is not None:
+             logits = logits / self.config.final_logit_softcapping
+             logits = torch.tanh(logits)
+             logits = logits * self.config.final_logit_softcapping
+
+         loss = None
+         if labels is not None:
+             loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+
+     if not return_dict:
+         output = (logits,) + outputs[1:]
+         return (loss,) + output if loss is not None else output
+
+     return CausalLMOutputWithPast(
+         loss=loss,
+         logits=logits,
+         past_key_values=outputs.past_key_values,
+         hidden_states=outputs.hidden_states,
+         attentions=outputs.attentions,
+     )
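The hunk above (which appears to correspond to `liger_kernel/transformers/model/gemma2.py`; the path is not shown in the diff) replaces the logit-materializing loss path with `LigerFusedLinearCrossEntropyLoss` during training. A minimal sketch of how such a forward is typically wired in by monkey-patching the model class; the import path, checkpoint name, and the CUDA/bfloat16 setup are assumptions, and the released package also ships `apply_liger_kernel_*` convenience patchers that do this wiring for you:

```python
import torch
from transformers import AutoTokenizer, Gemma2ForCausalLM

# Assumed import path for the patched forward shown in the hunk above.
from liger_kernel.transformers.model.gemma2 import lce_forward

# Replace the stock forward so training computes the loss without materializing logits.
Gemma2ForCausalLM.forward = lce_forward

model = Gemma2ForCausalLM.from_pretrained(
    "google/gemma-2-9b", attn_implementation="eager", torch_dtype=torch.bfloat16
).to("cuda")
model.train()
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")

batch = tokenizer("What is your favorite condiment?", return_tensors="pt").to("cuda")
out = model(**batch, labels=batch["input_ids"])
# In training mode with labels, the loss comes from the fused linear cross-entropy
# kernel and out.logits stays None, so the (seq_len, vocab_size) logits are never allocated.
print(out.loss, out.logits)
```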
@@ -1,7 +1,9 @@
  from typing import List, Optional, Tuple, Union

  import torch
+ from packaging import version
  from torch.nn import CrossEntropyLoss
+ from transformers import __version__ as transformers_version
  from transformers.models.qwen2_vl.modeling_qwen2_vl import (
      _CONFIG_FOR_DOC,
      QWEN2_VL_INPUTS_DOCSTRING,
@@ -80,8 +82,6 @@ def lce_forward(
      >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
      "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
      ```"""
-     # FIXME: The code is outdated and not compatible with transformer >= 4.46.1
-
      output_attentions = (
          output_attentions
          if output_attentions is not None
@@ -100,27 +100,53 @@ def lce_forward(
          inputs_embeds = self.model.embed_tokens(input_ids)
          if pixel_values is not None:
              pixel_values = pixel_values.type(self.visual.get_dtype())
-             image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).to(
-                 inputs_embeds.device
+             image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+             n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
+             n_image_features = image_embeds.shape[0]
+             if n_image_tokens != n_image_features:
+                 raise ValueError(
+                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                 )
+             image_mask = (
+                 (input_ids == self.config.image_token_id)
+                 .unsqueeze(-1)
+                 .expand_as(inputs_embeds)
+                 .to(inputs_embeds.device)
              )
-             image_mask = input_ids == self.config.image_token_id
-             if self.training:
-                 inputs_embeds = inputs_embeds.clone()
-             inputs_embeds[image_mask] = image_embeds
+             image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+             inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
          if pixel_values_videos is not None:
              pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
-             video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw).to(
-                 inputs_embeds.device
+             video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
+             n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
+             n_video_features = video_embeds.shape[0]
+             if n_video_tokens != n_video_features:
+                 raise ValueError(
+                     f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+                 )
+             video_mask = (
+                 (input_ids == self.config.video_token_id)
+                 .unsqueeze(-1)
+                 .expand_as(inputs_embeds)
+                 .to(inputs_embeds.device)
              )
-             video_mask = input_ids == self.config.video_token_id
-             inputs_embeds[video_mask] = video_embeds
+             video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+             inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+
          if attention_mask is not None:
              attention_mask = attention_mask.to(inputs_embeds.device)
-     # The code is copied from https://github.com/huggingface/transformers/pull/33487
-     if position_ids is None and input_ids is not None:
-         position_ids, _ = self.get_rope_index(
-             input_ids, image_grid_thw, video_grid_thw, attention_mask
-         )
+
+     if version.parse(transformers_version) > version.parse("4.46.2"):
+         # NOTE: this bug fix for qwen2-vl is not applied until transformers 4.47.0
+         # https://github.com/huggingface/transformers/issues/33401
+         # While correct, this breaks equivalence with past versions of Qwen2-VL from
+         # transformers and leads to failed tests or users noticing differences in results.
+         # TODO: remove above conditional when liger drops support for transformers<4.47.0
+         if position_ids is None and input_ids is not None:
+             position_ids, _ = self.get_rope_index(
+                 input_ids, image_grid_thw, video_grid_thw, attention_mask
+             )

      outputs = self.model(
          input_ids=None,
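The functional core of the hunk above is the switch from in-place boolean-index assignment to `masked_scatter` when splicing image and video embeddings into `inputs_embeds`, plus explicit token/feature count checks and a `packaging`-based gate on the installed transformers version. A tiny self-contained sketch of the `masked_scatter` pattern, with made-up shapes and a placeholder token id standing in for `config.image_token_id`:

```python
import torch

inputs_embeds = torch.zeros(1, 5, 4)           # (batch, seq_len, hidden)
token_ids = torch.tensor([[7, 99, 99, 7, 7]])  # 99 stands in for the image placeholder id
image_embeds = torch.randn(2, 4)               # one embedding row per placeholder token

# Expand the placeholder mask to (batch, seq_len, hidden) and scatter the image
# embeddings into those positions; masked_scatter returns a new tensor instead of
# writing through a boolean index on a possibly non-writable inference tensor.
mask = (token_ids == 99).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(mask, image_embeds.to(inputs_embeds.dtype))

print(inputs_embeds[0, 1:3])  # these rows now equal image_embeds[0] and image_embeds[1]
```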