liger-kernel 0.3.1__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. liger_kernel/ops/cross_entropy.py +5 -39
  2. liger_kernel/ops/experimental/mm_int8int2.py +355 -0
  3. liger_kernel/ops/fused_linear_cross_entropy.py +12 -9
  4. liger_kernel/ops/fused_linear_jsd.py +245 -0
  5. liger_kernel/ops/geglu.py +2 -2
  6. liger_kernel/ops/jsd.py +176 -0
  7. liger_kernel/ops/kl_div.py +2 -2
  8. liger_kernel/ops/rms_norm.py +67 -42
  9. liger_kernel/ops/swiglu.py +2 -2
  10. liger_kernel/ops/utils.py +62 -1
  11. liger_kernel/transformers/__init__.py +3 -0
  12. liger_kernel/transformers/functional.py +4 -0
  13. liger_kernel/transformers/fused_linear_jsd.py +98 -0
  14. liger_kernel/transformers/jsd.py +75 -0
  15. liger_kernel/transformers/model/gemma.py +124 -1
  16. liger_kernel/transformers/model/llama.py +135 -4
  17. liger_kernel/transformers/model/mistral.py +3 -0
  18. liger_kernel/transformers/model/mixtral.py +153 -2
  19. liger_kernel/transformers/model/mllama.py +274 -0
  20. liger_kernel/transformers/model/phi3.py +140 -2
  21. liger_kernel/transformers/model/qwen2.py +123 -2
  22. liger_kernel/transformers/model/qwen2_vl.py +8 -1
  23. liger_kernel/transformers/monkey_patch.py +158 -7
  24. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.0.dist-info}/METADATA +60 -28
  25. liger_kernel-0.4.0.dist-info/NOTICE +58 -0
  26. liger_kernel-0.4.0.dist-info/RECORD +48 -0
  27. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.0.dist-info}/WHEEL +1 -1
  28. liger_kernel-0.3.1.dist-info/NOTICE +0 -4
  29. liger_kernel-0.3.1.dist-info/RECORD +0 -42
  30. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.0.dist-info}/LICENSE +0 -0
  31. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.0.dist-info}/top_level.txt +0 -0
@@ -21,7 +21,7 @@ from liger_kernel.transformers.fused_linear_cross_entropy import (
  @replace_return_docstrings(
      output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
  )
- def lce_forward(
+ def lce_forward_deprecated(
      self,
      input_ids: torch.LongTensor = None,
      attention_mask: Optional[torch.Tensor] = None,
@@ -109,8 +109,9 @@ def lce_forward(

      else:
          logits = self.lm_head(hidden_states)
-         logits = logits.float()
          if labels is not None:
+             # Upcast to float if we need to compute the loss to avoid potential precision issues
+             logits = logits.float()
              # Shift so that tokens < n predict n
              shift_logits = logits[..., :-1, :].contiguous()
              shift_labels = labels[..., 1:].contiguous()
@@ -133,3 +134,123 @@ def lce_forward(
          hidden_states=outputs.hidden_states,
          attentions=outputs.attentions,
      )
+
+
+ @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
+ @replace_return_docstrings(
+     output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+ )
+ def lce_forward(
+     self,
+     input_ids: torch.LongTensor = None,
+     attention_mask: Optional[torch.Tensor] = None,
+     position_ids: Optional[torch.LongTensor] = None,
+     past_key_values: Optional[List[torch.FloatTensor]] = None,
+     inputs_embeds: Optional[torch.FloatTensor] = None,
+     labels: Optional[torch.LongTensor] = None,
+     use_cache: Optional[bool] = None,
+     output_attentions: Optional[bool] = None,
+     output_hidden_states: Optional[bool] = None,
+     return_dict: Optional[bool] = None,
+     cache_position: Optional[torch.LongTensor] = None,
+     num_logits_to_keep: int = 0,
+     **loss_kwargs,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+     r"""
+     Args:
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+         num_logits_to_keep (`int`, *optional*):
+             Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+             `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+             token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+
+     Returns:
+
+     Example:
+
+     ```python
+     >>> from transformers import AutoTokenizer, Qwen2ForCausalLM
+
+     >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+     >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+     >>> prompt = "Hey, are you conscious? Can you talk to me?"
+     >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+     >>> # Generate
+     >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+     "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+     ```"""
+
+     output_attentions = (
+         output_attentions
+         if output_attentions is not None
+         else self.config.output_attentions
+     )
+     output_hidden_states = (
+         output_hidden_states
+         if output_hidden_states is not None
+         else self.config.output_hidden_states
+     )
+     return_dict = (
+         return_dict if return_dict is not None else self.config.use_return_dict
+     )
+
+     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+     outputs = self.model(
+         input_ids=input_ids,
+         attention_mask=attention_mask,
+         position_ids=position_ids,
+         past_key_values=past_key_values,
+         inputs_embeds=inputs_embeds,
+         use_cache=use_cache,
+         output_attentions=output_attentions,
+         output_hidden_states=output_hidden_states,
+         return_dict=return_dict,
+         cache_position=cache_position,
+     )
+
+     hidden_states = outputs[0]
+
+     logits = None
+     loss = None
+     # if in training mode, don't materialize logits
+     if self.training and (labels is not None):
+         # We do the same thing as ForCausalLMLoss but using Liger FLCE
+
+         shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+         shift_labels = labels[..., 1:].contiguous()
+
+         # flatten tokens
+         shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+         shift_labels = shift_labels.view(-1)
+
+         reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
+         lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
+
+         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+         if reduction == "sum":
+             loss /= loss_kwargs["num_items_in_batch"]
+
+     else:  # if in inference mode materialize logits
+         logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+         if labels is not None:
+             loss = self.loss_function(
+                 logits=logits,
+                 labels=labels,
+                 vocab_size=self.config.vocab_size,
+                 **loss_kwargs,
+             )
+
+     return CausalLMOutputWithPast(
+         loss=loss,
+         logits=logits,
+         past_key_values=outputs.past_key_values,
+         hidden_states=outputs.hidden_states,
+         attentions=outputs.attentions,
+     )
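The `reduction="sum"` branch in the new `lce_forward` above is what keeps gradient accumulation correct: when the trainer passes `num_items_in_batch`, the loss is summed per micro-batch and divided once by the global token count instead of averaging each micro-batch separately (the issue referenced by `TRANSFORMER_DEPRECATION_WARNING` later in this diff). A minimal arithmetic sketch of the difference, using plain `torch.nn.functional.cross_entropy` on made-up tensors rather than the Liger kernel:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Two made-up micro-batches with different numbers of tokens.
logits_a, labels_a = torch.randn(3, 10), torch.randint(0, 10, (3,))
logits_b, labels_b = torch.randn(5, 10), torch.randint(0, 10, (5,))

# Reference: a single full-batch mean weights every token equally.
full_batch = F.cross_entropy(
    torch.cat([logits_a, logits_b]), torch.cat([labels_a, labels_b])
)

# Averaging per-micro-batch means over-weights the smaller micro-batch.
per_step_mean = (
    F.cross_entropy(logits_a, labels_a) + F.cross_entropy(logits_b, labels_b)
) / 2

# Pattern used above: sum per micro-batch, then divide by the global token
# count that the trainer passes as num_items_in_batch.
num_items_in_batch = labels_a.numel() + labels_b.numel()
summed = (
    F.cross_entropy(logits_a, labels_a, reduction="sum")
    + F.cross_entropy(logits_b, labels_b, reduction="sum")
) / num_items_in_batch

print(torch.allclose(summed, full_batch))        # True
print(torch.allclose(per_step_mean, full_batch)) # False (in general)
```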
@@ -80,6 +80,7 @@ def lce_forward(
      >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
      "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
      ```"""
+     # FIXME: The code is outdated and not compatible with transformer >= 4.46.1

      output_attentions = (
          output_attentions
@@ -115,6 +116,11 @@ def lce_forward(
              inputs_embeds[video_mask] = video_embeds
          if attention_mask is not None:
              attention_mask = attention_mask.to(inputs_embeds.device)
+     # The code is copied from https://github.com/huggingface/transformers/pull/33487
+     if position_ids is None and input_ids is not None:
+         position_ids, _ = self.get_rope_index(
+             input_ids, image_grid_thw, video_grid_thw, attention_mask
+         )

      outputs = self.model(
          input_ids=None,
@@ -145,8 +151,9 @@ def lce_forward(
          loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
      else:
          logits = self.lm_head(hidden_states)
-         logits = logits.float()
          if labels is not None:
+             # Upcast to float if we need to compute the loss to avoid potential precision issues
+             logits = logits.float()
              # Shift so that tokens < n predict n
              shift_logits = logits[..., :-1, :].contiguous()
              shift_labels = labels[..., 1:].contiguous()
@@ -3,17 +3,34 @@ import logging
  from functools import partial
  from typing import Callable

+ import transformers
+ from packaging import version
  from transformers import PreTrainedModel

  from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
  from liger_kernel.transformers.geglu import LigerGEGLUMLP
  from liger_kernel.transformers.layer_norm import LigerLayerNorm
  from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
+ from liger_kernel.transformers.model.gemma import (
+     lce_forward_deprecated as gemma_lce_forward_deprecated,
+ )
  from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
+ from liger_kernel.transformers.model.llama import (
+     lce_forward_deprecated as llama_lce_forward_deprecated,
+ )
  from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
  from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
+ from liger_kernel.transformers.model.mixtral import (
+     lce_forward_deprecated as mixtral_lce_forward_deprecated,
+ )
  from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
+ from liger_kernel.transformers.model.phi3 import (
+     lce_forward_deprecated as phi3_lce_forward_deprecated,
+ )
  from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
+ from liger_kernel.transformers.model.qwen2 import (
+     lce_forward_deprecated as qwen2_lce_forward_deprecated,
+ )
  from liger_kernel.transformers.rms_norm import LigerRMSNorm
  from liger_kernel.transformers.rope import liger_rotary_pos_emb
  from liger_kernel.transformers.swiglu import (
@@ -22,7 +39,11 @@ from liger_kernel.transformers.swiglu import (
      LigerSwiGLUMLP,
  )

+ transformer_version = version.parse(transformers.__version__)
+
  logger = logging.getLogger(__name__)
+ SUPPORTED_TRANSFORMER_VERSION = "4.46.1"
+ TRANSFORMER_DEPRECATION_WARNING = "Support for transformers versions < 4.46.1 will soon be discontinued due to issues with incorrect gradient accumulation. \n Please consider upgrading to avoid potential issues. See details: https://github.com/huggingface/transformers/pull/34191"


  def _bind_method_to_module(module, method_name: str, new_method: Callable):
@@ -88,7 +109,11 @@ def apply_liger_kernel_to_llama(
      if cross_entropy:
          modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
      if fused_linear_cross_entropy:
-         modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
+         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+             modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
+         else:  # if version < 4.46.1
+             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+             modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated

      if model is not None:
          # The model instance already exists, so we need to additionally patch the
@@ -117,6 +142,110 @@ def apply_liger_kernel_to_llama(
                  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


+ def apply_liger_kernel_to_mllama(
+     rope: bool = True,
+     cross_entropy: bool = False,
+     fused_linear_cross_entropy: bool = True,
+     layer_norm: bool = True,
+     rms_norm: bool = True,
+     swiglu: bool = True,
+     model: PreTrainedModel = None,
+ ) -> None:
+     """
+     Apply Liger kernels to replace original implementation in HuggingFace MLlama models.
+     NOTE: MLlama is not available in transformers<4.45.0
+
+     Args:
+         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+         fused_linear_cross_entropy (bool):
+             Whether to apply Liger's fused linear cross entropy loss. Default is True.
+             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+         swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+         loaded. Default is None.
+     """
+
+     assert not (
+         cross_entropy and fused_linear_cross_entropy
+     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+
+     from transformers.models.mllama import modeling_mllama
+     from transformers.models.mllama.modeling_mllama import (
+         MllamaForCausalLM,
+         MllamaForConditionalGeneration,
+         MllamaTextModel,
+         MllamaVisionModel,
+     )
+
+     from liger_kernel.transformers.model.mllama import lce_forward as mllama_lce_forward
+     from liger_kernel.transformers.model.mllama import (
+         lce_forward_deprecated as mllama_lce_forward_deprecated,
+     )
+
+     if rope:
+         modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
+     if layer_norm:
+         modeling_mllama.nn.LayerNorm = LigerLayerNorm
+     if rms_norm:
+         modeling_mllama.MllamaTextRMSNorm = LigerRMSNorm
+     if swiglu:
+         modeling_mllama.MllamaTextMLP = LigerSwiGLUMLP
+     if cross_entropy:
+         modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
+     if fused_linear_cross_entropy:
+         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+             modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
+         else:  # if version < 4.46.1
+             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+             modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated
+
+     if model is not None:
+         # The model instance already exists, so we need to additionally patch the
+         # instance variables that reference already-instantiated modules
+
+         if isinstance(model, MllamaForConditionalGeneration):
+             language_model: MllamaForCausalLM = model.language_model
+             vision_model: MllamaVisionModel = model.vision_model
+             text_model: MllamaTextModel = language_model.model
+         elif isinstance(model, MllamaForCausalLM):
+             text_model = model.model
+             vision_model = None
+         elif isinstance(model, MllamaTextModel):
+             text_model = model
+             vision_model = None
+         else:
+             raise ValueError(f"Unsupported Mllama model type: {type(model)}")
+
+         if text_model:
+             if rms_norm:
+                 _patch_rms_norm_module(text_model.norm)
+             for decoder_layer in text_model.layers:
+                 if swiglu:
+                     _bind_method_to_module(
+                         decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
+                     )
+                 if rms_norm:
+                     _patch_rms_norm_module(decoder_layer.input_layernorm)
+                     _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+         if vision_model:
+             _patch_layer_norm_module(vision_model.layernorm_pre)
+             _patch_layer_norm_module(vision_model.layernorm_post)
+
+             for layer in vision_model.transformer.layers:
+                 if layer_norm:
+                     _patch_layer_norm_module(layer.input_layernorm)
+                     _patch_layer_norm_module(layer.post_attention_layernorm)
+
+             for layer in vision_model.global_transformer.layers:
+                 if layer_norm:
+                     _patch_layer_norm_module(layer.input_layernorm)
+                     _patch_layer_norm_module(layer.post_attention_layernorm)
+
+
  def apply_liger_kernel_to_mistral(
      rope: bool = True,
      cross_entropy: bool = False,
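The new `apply_liger_kernel_to_mllama` entry point added above follows the same pattern as the other `apply_liger_kernel_to_*` functions: call it before the model is instantiated to patch `modeling_mllama` at the module level, or pass an already-loaded model through `model=` to patch its instantiated submodules. A hedged usage sketch (the checkpoint path is a placeholder, not a real model id):

```python
from transformers import MllamaForConditionalGeneration  # Mllama requires transformers >= 4.45.0

from liger_kernel.transformers import apply_liger_kernel_to_mllama

# Patch the modeling code before the model is created, so RoPE, RMSNorm, SwiGLU,
# LayerNorm, and the fused linear cross entropy forward are swapped in.
apply_liger_kernel_to_mllama()

model = MllamaForConditionalGeneration.from_pretrained(
    "path/to/llama-3.2-vision-checkpoint"  # placeholder path
)

# Alternatively, patch an instance that has already been loaded:
# apply_liger_kernel_to_mllama(model=model)
```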
@@ -219,7 +348,11 @@ def apply_liger_kernel_to_mixtral(
      if cross_entropy:
          modeling_mixtral.CrossEntropyLoss = LigerCrossEntropyLoss
      if fused_linear_cross_entropy:
-         modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
+         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+             modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
+         else:  # if version < 4.46.1
+             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+             modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
      if swiglu:
          modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP

@@ -295,7 +428,11 @@ def apply_liger_kernel_to_gemma(
      if geglu:
          modeling_gemma.GemmaMLP = LigerGEGLUMLP
      if fused_linear_cross_entropy:
-         modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
+         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+             modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
+         else:  # if version < 4.46.1
+             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+             modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated

      if model is not None:
          # The model instance already exists, so we need to additionally patch the
@@ -426,8 +563,16 @@ def apply_liger_kernel_to_qwen2(
          modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
      if cross_entropy:
          modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
+
+     # import pdb; pdb.set_trace()
      if fused_linear_cross_entropy:
-         modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
+
+         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+             modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
+         else:  # if version < 4.46.1
+             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+             modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated
+
      if swiglu:
          modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP

@@ -453,6 +598,7 @@ def apply_liger_kernel_to_qwen2(
              if rms_norm:
                  _patch_rms_norm_module(decoder_layer.input_layernorm)
                  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+     print("Applied Liger kernels to Qwen2")


  def apply_liger_kernel_to_qwen2_vl(
@@ -465,7 +611,7 @@ def apply_liger_kernel_to_qwen2_vl(
  ) -> None:
      """
      Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
-     NOTE: Qwen2-VL is not available in transformers<=4.44.2
+     NOTE: Qwen2-VL is not available in transformers<4.45.0

      Args:
          cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
@@ -571,7 +717,11 @@ def apply_liger_kernel_to_phi3(
      if cross_entropy:
          modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
      if fused_linear_cross_entropy:
-         modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
+         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+             modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
+         else:  # if version < 4.46.1
+             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+             modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated

      if model is not None:
          # The model instance already exists, so we need to additionally patch the
@@ -602,6 +752,8 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
      "gemma": apply_liger_kernel_to_gemma,
      "gemma2": apply_liger_kernel_to_gemma2,
      "llama": apply_liger_kernel_to_llama,
+     "mllama": apply_liger_kernel_to_mllama,
+     "mllama_text_model": apply_liger_kernel_to_mllama,
      "mistral": apply_liger_kernel_to_mistral,
      "mixtral": apply_liger_kernel_to_mixtral,
      "qwen2": apply_liger_kernel_to_qwen2,
@@ -687,7 +839,6 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
          for key, value in kwargs.items()
          if key in apply_fn_signature.parameters
      }
-
      logger.info(
          f"Applying Liger kernels to model instance with model type: {model_type} with kwargs: {applicable_kwargs}"
      )
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: liger_kernel
- Version: 0.3.1
+ Version: 0.4.0
  Summary: Efficient Triton kernels for LLM Training
  License: BSD 2-CLAUSE LICENSE
  Copyright 2024 LinkedIn Corporation
@@ -31,7 +31,7 @@ Description-Content-Type: text/markdown
  License-File: LICENSE
  License-File: NOTICE
  Requires-Dist: torch>=2.1.2
- Requires-Dist: triton>=2.3.0
+ Requires-Dist: triton>=2.3.1
  Provides-Extra: dev
  Requires-Dist: transformers>=4.44.2; extra == "dev"
  Requires-Dist: matplotlib>=3.7.2; extra == "dev"
@@ -40,10 +40,13 @@ Requires-Dist: black>=24.4.2; extra == "dev"
  Requires-Dist: isort>=5.13.2; extra == "dev"
  Requires-Dist: pytest>=7.1.2; extra == "dev"
  Requires-Dist: datasets>=2.19.2; extra == "dev"
+ Requires-Dist: torchvision>=0.16.2; extra == "dev"
  Requires-Dist: seaborn; extra == "dev"
  Provides-Extra: transformers
  Requires-Dist: transformers~=4.0; extra == "transformers"

+ <a name="readme-top"></a>
+
  # Liger Kernel: Efficient Triton Kernels for LLM Training


@@ -52,6 +55,7 @@ Requires-Dist: transformers~=4.0; extra == "transformers"
  <th style="padding: 10px;" colspan="2">Stable</th>
  <th style="padding: 10px;" colspan="2">Nightly</th>
  <th style="padding: 10px;">Discord</th>
+ <th style="padding: 10px;">Gurubase (experimental)</th>
  </tr>
  <tr>
  <td style="padding: 10px;">
@@ -79,6 +83,11 @@ Requires-Dist: transformers~=4.0; extra == "transformers"
  <img src="https://dcbadge.vercel.app/api/server/gpumode?style=flat" alt="Join Our Discord">
  </a>
  </td>
+ <td style="padding: 10px;">
+ <a href="https://gurubase.io/g/liger-kernel">
+ <img src="https://img.shields.io/badge/Gurubase-Ask%20Liger%20Kernel%20Guru-006BFF" alt="Ask Liger Kernel Guru">
+ </a>
+ </td>
  </tr>
  </table>

@@ -86,11 +95,12 @@ Requires-Dist: transformers~=4.0; extra == "transformers"

  <img src="https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/logo-banner.png">

- [Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [APIs](#apis) | [Structure](#structure) | [Contributing](#contributing) | [Acknowledgement](#acknowledgement)
+ [Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [APIs](#apis) | [Cite our work](#cite-this-work)

  <details>
  <summary>Latest News 🔥</summary>
-
+
+ - [2024/10/21] We have released the tech report of Liger Kernel on Arxiv: https://arxiv.org/pdf/2410.10989
  - [2024/9/6] We release v0.2.1 ([X post](https://x.com/liger_kernel/status/1832168197002510649)). 2500+ Stars, 10+ New Contributors, 50+ PRs, 50k Downloads in two weeks!
  - [2024/8/31] CUDA MODE talk, [Liger-Kernel: Real-world Triton kernel for LLM Training](https://youtu.be/gWble4FreV4?si=dxPeIchhkJ36Mbns), [Slides](https://github.com/cuda-mode/lectures?tab=readme-ov-file#lecture-28-liger-kernel)
  - [2024/8/23] Official release: check out our [X post](https://x.com/hsu_byron/status/1827072737673982056)
@@ -148,11 +158,18 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and

  ## Installation

- ### Dependencies
+ ### Dependencies
+
+ #### CUDA

  - `torch >= 2.1.2`
  - `triton >= 2.3.0`

+ #### ROCm
+
+ - `torch >= 2.5.0` Install according to the instruction in Pytorch official webpage.
+ - `triton >= 3.0.0` Install from pypi. (e.g. `pip install triton==3.0.0`)
+
  ### Optional Dependencies

  - `transformers >= 4.x`: Required if you plan to use the transformers models patching APIs. The specific model you are working will dictate the minimum version of transformers.
@@ -182,6 +199,7 @@ pip install -e .
  pip install -e .[transformers]
  ```

+
  ## Getting Started

  There are a couple of ways to apply Liger kernels, depending on the level of customization required.
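As a point of reference for the patching APIs touched throughout this diff, the lowest-friction path is a single `apply_liger_kernel_to_*` call before the model is constructed. A minimal sketch using the LLaMA entry point listed in the API table below; the checkpoint name is a placeholder and the keyword arguments are spelled out only for clarity:

```python
import torch
import transformers

from liger_kernel.transformers import apply_liger_kernel_to_llama

# Swap the Hugging Face LLaMA kernels for the Liger Triton kernels before loading.
apply_liger_kernel_to_llama(
    rope=True,
    rms_norm=True,
    swiglu=True,
    fused_linear_cross_entropy=True,  # avoids materializing logits during training
)

model = transformers.AutoModelForCausalLM.from_pretrained(
    "path/to/llama-checkpoint",  # placeholder
    torch_dtype=torch.bfloat16,
)
```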
@@ -274,6 +292,7 @@ loss.backward()
  | **Model** | **API** | **Supported Operations** |
  |-------------|--------------------------------------------------------------|-------------------------------------------------------------------------|
  | LLaMA 2 & 3 | `liger_kernel.transformers.apply_liger_kernel_to_llama` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+ | LLaMA 3.2-Vision | `liger_kernel.transformers.apply_liger_kernel_to_mllama` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Mistral | `liger_kernel.transformers.apply_liger_kernel_to_mistral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
@@ -296,6 +315,8 @@ loss.backward()
  | CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
  | FusedLinearCrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
  | KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
+ | JSD | `liger_kernel.transformers.LigerJSD` |
+ | FusedLinearJSD | `liger_kernel.transformers.LigerFusedLinearJSD` |

  - **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction.
  - **LayerNorm**: [LayerNorm](https://arxiv.org/pdf/1607.06450), which centers and normalizes activations across the feature dimension, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and achieves ~2X speedup.
@@ -310,35 +331,23 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
  <!-- TODO: verify vocab sizes are accurate -->
  - **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for 128k vocab size. **This is highly effective for large batch size, large sequence length, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage.
  - **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward into a single triton kernel, with reduction done outside the kernel. It achieves ~1.5X speed and ~15% memory reduction for 128K vocab size.
+ - **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence) is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speed and ~54% memory reduction for 128k vocab size.
+ - **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128k vocab size where batch size $\times$ sequence length is 8192.
+

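`LigerJSD` and `LigerFusedLinearJSD` are the new distillation-oriented losses registered in the API table above. A minimal sketch of how they might be called; the argument names, the order of student/teacher inputs, and the assumption that both inputs are log-probabilities are inferred for illustration and are not confirmed by this diff:

```python
import torch

from liger_kernel.transformers import LigerFusedLinearJSD, LigerJSD

BT, V, H = 16, 128, 64  # batch * seq_len, vocab size, hidden size

# LigerJSD: assumed to take student and teacher log-probabilities, in that order.
jsd = LigerJSD()
student_logp = torch.randn(BT, V).log_softmax(dim=-1)
teacher_logp = torch.randn(BT, V).log_softmax(dim=-1)
loss = jsd(student_logp, teacher_logp)

# LigerFusedLinearJSD: assumed to take pre-lm_head hidden states plus the student
# and teacher head weights, so the full (BT, V) logits are never materialized.
fused_jsd = LigerFusedLinearJSD()
student_hidden, teacher_hidden = torch.randn(BT, H), torch.randn(BT, H)
student_head, teacher_head = torch.randn(V, H), torch.randn(V, H)
loss = fused_jsd(student_hidden, student_head, teacher_hidden, teacher_head)
```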
  ### Experimental Kernels

  | **Kernel** | **API** |
  |---------------------------------|-------------------------------------------------------------|
  | Embedding | `liger_kernel.transformers.experimental.LigerEmbedding` |
-
+ | Matmul int2xint8 | `liger_kernel.transformers.experimental.matmul` |

  - **Embedding**: [Embedding](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html) is implemented by fusing embedding lookup and output operations. It achieves a peak speedup of ~1.5x in the forward pass and an overall speedup of ~1.1x.
-
+ - **Matmul int2xint8**: implemented with cache-tiled matrix multiplication and by fusing the matmul with the unpacking process, which achieves a considerable speedup and performs on par with `@torch.compile`.
  <!-- TODO: be more specific about batch size -->
  > **Note:**
  > Reported speedups and memory reductions are with respect to the LLaMA 3-8B Hugging Face layer implementations. All models use 4K hidden size and 4K sequence length and are evaluated based on memory usage and wall time for the forward+backward pass on a single NVIDIA A100 80G GPU using small batch sizes. Liger kernels exhibit more efficient scaling to larger batch sizes, detailed further in the [Benchmark](./benchmark) folder.

- ## Note on ML Compiler
-
- ### Torch Compile
-
- Since Liger Kernel is 100% Triton-based, it works seamlessly with [`torch.compile`](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html). In the following example, Liger Kernel can further optimize the model on top of Torch Compile, reducing the memory by more than half.
-
- | Configuration | Throughput (tokens/sec) | Memory Reserved (GB) |
- |--------------------------------|----------------------------|-------------------------|
- | Torch Compile | 3780 | 66.4 |
- | Torch Compile + Liger Kernel | 3702 | 31.0 |
-
- > **Note:**
- > 1. Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Seq Len = 4096, Data Type = `bf16`, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s.
- > 2. Tested on torch `2.5.0.dev20240731+cu118`
-
  ## Contributing

  [CONTRIBUTING GUIDE](https://github.com/linkedin/Liger-Kernel/blob/main/CONTRIBUTING.md)

@@ -372,7 +381,14 @@ Many thanks to the contributors to these projects for their invaluable work that

  ## License

- [BSD 2-CLAUSE](https://github.com/linkedin/Liger-Kernel/blob/main/LICENSE)
+ This project is licensed under the [BSD 2-CLAUSE](https://github.com/linkedin/Liger-Kernel/blob/main/LICENSE) License (see `LICENSE` for details).
+ It also includes components from projects licensed under:
+
+ - Apache License 2.0 (see `LICENSE-APACHE-2.0` for details).
+ - MIT License (see `LICENSE-MIT-AutoAWQ` for details).
+ - MIT License (see `LICENSE-MIT-Efficient Cross Entropy` for details).
+ - MIT License (see `LICENSE-MIT-llmc` for details).
+ - MIT License (see `LICENSE-MIT-triton` for details).

  ## Contact

@@ -383,13 +399,29 @@ Many thanks to the contributors to these projects for their invaluable work that

  Biblatex entry:
  ```bib
- @software{liger2024,
- title = {Liger-Kernel: Efficient Triton Kernels for LLM Training},
- author = {Hsu, Pin-Lun and Dai, Yun and Kothapalli, Vignesh and Song, Qingquan and Tang, Shao and Zhu, Siyu},
- url = {https://github.com/linkedin/Liger-Kernel},
- year = {2024}
+ @article{hsu2024ligerkernelefficienttriton,
+ title={Liger Kernel: Efficient Triton Kernels for LLM Training},
+ author={Pin-Lun Hsu and Yun Dai and Vignesh Kothapalli and Qingquan Song and Shao Tang and Siyu Zhu and Steven Shimizu and Shivam Sahni and Haowen Ning and Yanning Chen},
+ year={2024},
+ eprint={2410.10989},
+ archivePrefix={arXiv},
+ primaryClass={cs.LG},
+ url={https://arxiv.org/abs/2410.10989},
+ journal={arXiv preprint arXiv:2410.10989},
  }
  ```

  ## Star History
  [![Star History Chart](https://api.star-history.com/svg?repos=linkedin/Liger-Kernel&type=Date)](https://star-history.com/#linkedin/Liger-Kernel&Date)
+
+ ## Contributors
+
+ <a href="https://github.com/linkedin/Liger-Kernel/graphs/contributors">
+ <img alt="contributors" src="https://contrib.rocks/image?repo=linkedin/Liger-Kernel"/>
+ </a>
+
+ <p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
+ <a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
+ ↑ Back to Top ↑
+ </a>
+ </p>