liger-kernel 0.1.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. liger_kernel/env_report.py +46 -0
  2. liger_kernel/ops/cross_entropy.py +130 -63
  3. liger_kernel/ops/experimental/embedding.py +143 -0
  4. liger_kernel/ops/fused_linear_cross_entropy.py +203 -126
  5. liger_kernel/ops/geglu.py +54 -42
  6. liger_kernel/ops/kl_div.py +247 -0
  7. liger_kernel/ops/layer_norm.py +236 -0
  8. liger_kernel/ops/rms_norm.py +220 -84
  9. liger_kernel/ops/rope.py +91 -84
  10. liger_kernel/ops/swiglu.py +48 -41
  11. liger_kernel/ops/utils.py +12 -0
  12. liger_kernel/transformers/__init__.py +22 -0
  13. liger_kernel/transformers/auto_model.py +33 -0
  14. liger_kernel/transformers/cross_entropy.py +11 -1
  15. liger_kernel/transformers/experimental/embedding.py +28 -0
  16. liger_kernel/transformers/functional.py +19 -0
  17. liger_kernel/transformers/fused_linear_cross_entropy.py +8 -2
  18. liger_kernel/transformers/geglu.py +4 -2
  19. liger_kernel/transformers/kl_div.py +13 -0
  20. liger_kernel/transformers/layer_norm.py +30 -0
  21. liger_kernel/transformers/model/gemma.py +138 -0
  22. liger_kernel/transformers/model/llama.py +1 -1
  23. liger_kernel/transformers/model/mistral.py +138 -0
  24. liger_kernel/transformers/model/mixtral.py +158 -0
  25. liger_kernel/transformers/model/phi3.py +136 -0
  26. liger_kernel/transformers/model/qwen2.py +135 -0
  27. liger_kernel/transformers/model/qwen2_vl.py +172 -0
  28. liger_kernel/transformers/monkey_patch.py +605 -14
  29. liger_kernel/transformers/rms_norm.py +23 -4
  30. liger_kernel/transformers/swiglu.py +24 -0
  31. liger_kernel/transformers/trainer_integration.py +2 -45
  32. liger_kernel-0.3.0.dist-info/METADATA +388 -0
  33. liger_kernel-0.3.0.dist-info/RECORD +42 -0
  34. {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.0.dist-info}/WHEEL +1 -1
  35. liger_kernel-0.1.0.dist-info/METADATA +0 -16
  36. liger_kernel-0.1.0.dist-info/RECORD +0 -27
  37. {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.0.dist-info}/LICENSE +0 -0
  38. {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.0.dist-info}/NOTICE +0 -0
  39. {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.0.dist-info}/top_level.txt +0 -0
liger_kernel/ops/utils.py CHANGED
@@ -1,3 +1,15 @@
+ """
+ This file incorporates code from Unsloth licensed under the Apache License, Version 2.0.
+ See the original Unsloth repository at https://github.com/unslothai/unsloth.
+
+ The following line
+ https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/utils.py#L23
+ is based on code from Unsloth, located at:
+ https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43
+
+ Modifications made by Yanning Chen, 2024.
+ """
+
  import functools
  import importlib
  from typing import Callable
liger_kernel/transformers/__init__.py CHANGED
@@ -1,6 +1,28 @@
+ from liger_kernel.transformers.auto_model import (  # noqa: F401
+     AutoLigerKernelForCausalLM,
+ )
+ from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss  # noqa: F401
+ from liger_kernel.transformers.fused_linear_cross_entropy import (  # noqa: F401
+     LigerFusedLinearCrossEntropyLoss,
+ )
+ from liger_kernel.transformers.geglu import LigerGEGLUMLP  # noqa: F401
+ from liger_kernel.transformers.layer_norm import LigerLayerNorm  # noqa: F401
  from liger_kernel.transformers.monkey_patch import (  # noqa: F401
+     _apply_liger_kernel,
+     _apply_liger_kernel_to_instance,
      apply_liger_kernel_to_gemma,
+     apply_liger_kernel_to_gemma2,
      apply_liger_kernel_to_llama,
      apply_liger_kernel_to_mistral,
      apply_liger_kernel_to_mixtral,
+     apply_liger_kernel_to_phi3,
+     apply_liger_kernel_to_qwen2,
+     apply_liger_kernel_to_qwen2_vl,
+ )
+ from liger_kernel.transformers.rms_norm import LigerRMSNorm  # noqa: F401
+ from liger_kernel.transformers.rope import liger_rotary_pos_emb  # noqa: F401
+ from liger_kernel.transformers.swiglu import (  # noqa: F401
+     LigerBlockSparseTop2MLP,
+     LigerPhi3SwiGLUMLP,
+     LigerSwiGLUMLP,
  )
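The expanded top-level exports above hint at the intended workflow: patch a supported architecture in `transformers`, then load the model as usual. A minimal sketch, assuming the no-argument form of the patch function; the model id and dtype are illustrative:

```python
import torch
import transformers

from liger_kernel.transformers import apply_liger_kernel_to_llama

# Monkey-patch the Hugging Face Llama modules with Liger kernels
# (RMSNorm, RoPE, SwiGLU, fused linear cross entropy, ...).
apply_liger_kernel_to_llama()

# Any Llama model instantiated afterwards picks up the patched modules.
model = transformers.AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B", torch_dtype=torch.bfloat16
)
```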
liger_kernel/transformers/auto_model.py ADDED
@@ -0,0 +1,33 @@
+ from transformers import AutoConfig, AutoModelForCausalLM
+
+ from liger_kernel.transformers.monkey_patch import _apply_liger_kernel
+
+
+ def _get_model_config(model_dir, **model_init_kwargs):
+     config = AutoConfig.from_pretrained(model_dir, **model_init_kwargs)
+     return config
+
+
+ class AutoLigerKernelForCausalLM(AutoModelForCausalLM):
+     """
+     This class is a drop-in replacement for AutoModelForCausalLM that applies the Liger Kernel to the model
+     if applicable.
+     """
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+         model_config = _get_model_config(pretrained_model_name_or_path, **kwargs)
+
+         # Determine the model type and apply the Liger Kernel if applicable
+         # Note: _apply_liger_kernel will only pass relevant kwargs to the apply_liger_kernel_to_* function
+         model_type = model_config.model_type
+         _apply_liger_kernel(model_type, **kwargs)
+
+         # Retain only the keyword args present in the model configuration
+         for k in list(kwargs.keys()):
+             if k not in model_config.__dict__:
+                 del kwargs[k]
+
+         return super().from_pretrained(
+             pretrained_model_name_or_path, *model_args, **kwargs
+         )
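A minimal usage sketch of the new auto class (the model id is illustrative); kernel-related kwargs, if any, are forwarded to the matching `apply_liger_kernel_to_*` function and then stripped before the config-driven load:

```python
from liger_kernel.transformers import AutoLigerKernelForCausalLM

# Behaves like AutoModelForCausalLM.from_pretrained, but first resolves the model
# type from the config (e.g. "llama") and applies the matching Liger kernels.
model = AutoLigerKernelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
```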
liger_kernel/transformers/cross_entropy.py CHANGED
@@ -6,6 +6,16 @@ from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
  class LigerCrossEntropyLoss(CrossEntropyLoss):
      def __init__(self, *args, **kwargs):
          super(LigerCrossEntropyLoss, self).__init__(*args, **kwargs)
+         assert (self.label_smoothing >= 0) and (
+             self.label_smoothing <= 1
+         ), f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}"
+         assert self.reduction in {
+             "mean",
+             "sum",
+             "none",
+         }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {self.reduction}"

      def forward(self, _input, target):
-         return LigerCrossEntropyFunction.apply(_input, target, self.ignore_index)
+         return LigerCrossEntropyFunction.apply(
+             _input, target, self.ignore_index, self.label_smoothing, self.reduction
+         )
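A minimal usage sketch of the extended loss, assuming a CUDA device and the usual `(N, num_classes)` logits layout:

```python
import torch

from liger_kernel.transformers import LigerCrossEntropyLoss

# Same constructor arguments as torch.nn.CrossEntropyLoss; the new assertions
# restrict label_smoothing to [0, 1] and reduction to {"mean", "sum", "none"}.
loss_fn = LigerCrossEntropyLoss(label_smoothing=0.1, reduction="mean")

logits = torch.randn(8, 32000, device="cuda", requires_grad=True)  # (N, num_classes)
target = torch.randint(0, 32000, (8,), device="cuda")              # (N,)
loss = loss_fn(logits, target)
loss.backward()
```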
liger_kernel/transformers/experimental/embedding.py ADDED
@@ -0,0 +1,28 @@
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+
+ from liger_kernel.ops.experimental.embedding import LigerEmbeddingFunction
+
+
+ class LigerEmbedding(nn.Module):
+     def __init__(
+         self, num_embeddings, embedding_dim, padding_idx: Optional[int] = None
+     ):
+         super().__init__()
+         self.num_embeddings = num_embeddings
+         self.embedding_dim = embedding_dim
+         self.padding_idx = padding_idx
+         self.weight = nn.Parameter(torch.randn(num_embeddings, embedding_dim))
+
+         if padding_idx is not None:
+             with torch.no_grad():
+                 self.weight[padding_idx].fill_(0)
+
+     def forward(self, indices):
+         embedded = LigerEmbeddingFunction.apply(self.weight, indices)
+         if self.padding_idx is not None:
+             embedded = embedded.clone()
+             embedded[indices == self.padding_idx] = 0
+         return embedded
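A minimal usage sketch of the experimental embedding, assuming a CUDA device; it mirrors `torch.nn.Embedding`, and the forward pass zeroes the rows selected by `padding_idx`:

```python
import torch

from liger_kernel.transformers.experimental.embedding import LigerEmbedding

emb = LigerEmbedding(num_embeddings=1000, embedding_dim=64, padding_idx=0).to("cuda")
indices = torch.tensor([1, 2, 0, 5], device="cuda")
out = emb(indices)  # shape (4, 64); the position where indices == 0 comes out as zeros
```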
liger_kernel/transformers/functional.py ADDED
@@ -0,0 +1,19 @@
+ from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
+ from liger_kernel.ops.fused_linear_cross_entropy import (
+     LigerFusedLinearCrossEntropyFunction,
+ )
+ from liger_kernel.ops.geglu import LigerGELUMulFunction
+ from liger_kernel.ops.kl_div import LigerKLDivLossFunction
+ from liger_kernel.ops.layer_norm import LigerLayerNormFunction
+ from liger_kernel.ops.rms_norm import LigerRMSNormFunction
+ from liger_kernel.ops.rope import LigerRopeFunction
+ from liger_kernel.ops.swiglu import LigerSiLUMulFunction
+
+ liger_swiglu = LigerSiLUMulFunction.apply
+ liger_cross_entropy = LigerCrossEntropyFunction.apply
+ liger_fused_linear_cross_entropy = LigerFusedLinearCrossEntropyFunction.apply
+ liger_geglu = LigerGELUMulFunction.apply
+ liger_rms_norm = LigerRMSNormFunction.apply
+ liger_rope = LigerRopeFunction.apply
+ liger_layer_norm = LigerLayerNormFunction.apply
+ liger_kl_div = LigerKLDivLossFunction.apply
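These aliases expose the raw `torch.autograd.Function.apply` entry points. A minimal sketch of calling one directly, with the argument order taken from the `LigerCrossEntropyLoss.forward` call shown earlier (CUDA device assumed; `Function.apply` takes positional arguments only):

```python
import torch

from liger_kernel.transformers.functional import liger_cross_entropy

logits = torch.randn(4, 32000, device="cuda", requires_grad=True)
target = torch.randint(0, 32000, (4,), device="cuda")

# Argument order mirrors LigerCrossEntropyLoss.forward above:
# (_input, target, ignore_index, label_smoothing, reduction).
loss = liger_cross_entropy(logits, target, -100, 0.0, "mean")
loss.backward()
```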
liger_kernel/transformers/fused_linear_cross_entropy.py CHANGED
@@ -9,7 +9,13 @@ class LigerFusedLinearCrossEntropyLoss(CrossEntropyLoss):
      def __init__(self, *args, **kwargs):
          super(LigerFusedLinearCrossEntropyLoss, self).__init__(*args, **kwargs)

-     def forward(self, lin_weight, _input, target):
+     def forward(self, lin_weight, _input, target, bias=None):
          return LigerFusedLinearCrossEntropyFunction.apply(
-             _input, lin_weight, target, self.ignore_index
+             _input,
+             lin_weight,
+             target,
+             bias,
+             self.ignore_index,
+             self.label_smoothing,
+             self.reduction,
          )
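A minimal usage sketch of the extended signature, assuming a CUDA device and illustrative shapes; passing the `lm_head` weight and flattened hidden states lets the kernel fuse the output projection with the loss, so the full logits matrix is never materialized:

```python
import torch

from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

hidden = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
lm_head_weight = torch.randn(
    32000, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True
)
target = torch.randint(0, 32000, (16,), device="cuda")

lce = LigerFusedLinearCrossEntropyLoss()
# bias defaults to None; ignore_index, label_smoothing and reduction come from the
# CrossEntropyLoss constructor as before.
loss = lce(lm_head_weight, hidden, target)
loss.backward()
```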
liger_kernel/transformers/geglu.py CHANGED
@@ -13,8 +13,10 @@ class LigerGEGLUMLP(nn.Module):
          self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
          self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
          # TODO: support exact GELU
-         if config.hidden_act not in ["gelu_pytorch_tanh"]:
-             raise ValueError(f"Activation function {config.hidden_act} not supported.")
+         # Right now Gemma 1, 1.1 and 2 models are all using `gelu_pytorch_tanh`
+         # https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/gemma/modeling_gemma.py#L175
+         # https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/activations.py#L46
+         # So we can safely assume the tanh approximation form is used all the time

      def forward(self, x):

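For context on the dropped check: `gelu_pytorch_tanh` is the tanh-approximated GELU, which is the only form the Liger GEGLU kernel currently assumes. A standalone illustration in plain PyTorch (not the Liger kernel itself):

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 8)

# "gelu_pytorch_tanh" in transformers maps to the tanh approximation of GELU,
# which is the gate activation form the Liger GEGLU kernel assumes.
gelu_tanh = F.gelu(x, approximate="tanh")
gelu_exact = F.gelu(x)  # exact GELU is still a TODO per the comment above

print((gelu_tanh - gelu_exact).abs().max())  # small but nonzero approximation error
```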
liger_kernel/transformers/kl_div.py ADDED
@@ -0,0 +1,13 @@
+ import torch.nn as nn
+
+ from liger_kernel.ops.kl_div import LigerKLDivLossFunction
+
+
+ class LigerKLDIVLoss(nn.KLDivLoss):
+     def __init__(self, *args, **kwargs):
+         super(LigerKLDIVLoss, self).__init__(*args, **kwargs)
+
+     def forward(self, y_pred, y_true):
+         return LigerKLDivLossFunction.apply(
+             y_pred, y_true, self.reduction, self.log_target
+         )
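A minimal usage sketch, following the `torch.nn.KLDivLoss` convention that `y_pred` holds log-probabilities and `y_true` holds probabilities unless `log_target=True` (CUDA device assumed):

```python
import torch
import torch.nn.functional as F

from liger_kernel.transformers.kl_div import LigerKLDIVLoss

kl = LigerKLDIVLoss()  # reduction / log_target are inherited from nn.KLDivLoss

y_pred = F.log_softmax(torch.randn(4, 128, device="cuda"), dim=-1)  # log-probabilities
y_true = F.softmax(torch.randn(4, 128, device="cuda"), dim=-1)      # probabilities
loss = kl(y_pred, y_true)
```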
liger_kernel/transformers/layer_norm.py ADDED
@@ -0,0 +1,30 @@
+ import torch
+ import torch.nn as nn
+
+ from liger_kernel.ops.layer_norm import LigerLayerNormFunction
+
+
+ class LigerLayerNorm(nn.Module):
+     def __init__(self, hidden_size, eps=1e-6, bias=False, init_fn="ones"):
+         super().__init__()
+         assert init_fn in [
+             "ones",
+             "zeros",
+         ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
+         self.hidden_size = hidden_size
+         self.eps = eps
+         self.weight = nn.Parameter(
+             torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)
+         )
+         self.bias = nn.Parameter(
+             torch.randn(hidden_size) if bias else torch.zeros(hidden_size)
+         )
+         self.variance_epsilon = eps
+
+     def forward(self, hidden_states):
+         return LigerLayerNormFunction.apply(
+             hidden_states, self.weight, self.bias, self.variance_epsilon
+         )
+
+     def extra_repr(self):
+         return f"{self.hidden_size}, eps={self.eps}"
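A minimal usage sketch, assuming a CUDA device and an illustrative hidden size:

```python
import torch

from liger_kernel.transformers import LigerLayerNorm

ln = LigerLayerNorm(hidden_size=4096, eps=1e-6).to("cuda")
hidden_states = torch.randn(2, 128, 4096, device="cuda")
out = ln(hidden_states)  # normalized over the last dimension
print(ln)                # extra_repr -> LigerLayerNorm(4096, eps=1e-06)
```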
liger_kernel/transformers/model/gemma.py ADDED
@@ -0,0 +1,138 @@
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+ from torch.nn import CrossEntropyLoss
+ from transformers.cache_utils import Cache
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+ from transformers.models.gemma.modeling_gemma import (
+     _CONFIG_FOR_DOC,
+     GEMMA_INPUTS_DOCSTRING,
+ )
+ from transformers.utils import (
+     add_start_docstrings_to_model_forward,
+     replace_return_docstrings,
+ )
+
+ from liger_kernel.transformers.fused_linear_cross_entropy import (
+     LigerFusedLinearCrossEntropyLoss,
+ )
+
+
+ @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
+ @replace_return_docstrings(
+     output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+ )
+ def lce_forward(
+     self,
+     input_ids: torch.LongTensor = None,
+     attention_mask: Optional[torch.Tensor] = None,
+     position_ids: Optional[torch.LongTensor] = None,
+     past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+     inputs_embeds: Optional[torch.FloatTensor] = None,
+     labels: Optional[torch.LongTensor] = None,
+     use_cache: Optional[bool] = None,
+     output_attentions: Optional[bool] = None,
+     output_hidden_states: Optional[bool] = None,
+     return_dict: Optional[bool] = None,
+     cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+     r"""
+
+     Copy paste transformers.models.gemma.modeling_gemma CausalLM forward with the loss replaced by Liger fused linear cross entropy
+
+     Args:
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+     Returns:
+
+     Example:
+
+     ```python
+     >>> from transformers import AutoTokenizer, GemmaForCausalLM
+
+     >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
+     >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
+
+     >>> prompt = "What is your favorite condiment?"
+     >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+     >>> # Generate
+     >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+     "What is your favorite condiment?"
+     ```"""
+     output_attentions = (
+         output_attentions
+         if output_attentions is not None
+         else self.config.output_attentions
+     )
+     output_hidden_states = (
+         output_hidden_states
+         if output_hidden_states is not None
+         else self.config.output_hidden_states
+     )
+     return_dict = (
+         return_dict if return_dict is not None else self.config.use_return_dict
+     )
+
+     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+     outputs = self.model(
+         input_ids=input_ids,
+         attention_mask=attention_mask,
+         position_ids=position_ids,
+         past_key_values=past_key_values,
+         inputs_embeds=inputs_embeds,
+         use_cache=use_cache,
+         output_attentions=output_attentions,
+         output_hidden_states=output_hidden_states,
+         return_dict=return_dict,
+         cache_position=cache_position,
+     )
+
+     hidden_states = outputs[0]
+
+     loss = None
+     logits = None
+
+     if self.training and (labels is not None):
+         shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+         shift_labels = labels[..., 1:].contiguous()
+
+         # flatten
+
+         shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+         shift_labels = shift_labels.view(-1)
+
+         lce = LigerFusedLinearCrossEntropyLoss()
+         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+
+     else:
+         logits = self.lm_head(hidden_states)
+         if labels is not None:
+             # Upcast to float if we need to compute the loss to avoid potential precision issues
+             logits = logits.float()
+             # Shift so that tokens < n predict n
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             # Flatten the tokens
+             shift_logits = shift_logits.view(-1, self.config.vocab_size)
+             shift_labels = shift_labels.view(-1)
+             # Ensure tensors are on the same device
+             shift_labels = shift_labels.to(shift_logits.device)
+             loss_fct = CrossEntropyLoss()
+             loss = loss_fct(shift_logits, shift_labels)
+
+     if not return_dict:
+         output = (logits,) + outputs[1:]
+         return (loss,) + output if loss is not None else output
+
+     return CausalLMOutputWithPast(
+         loss=loss,
+         logits=logits,
+         past_key_values=outputs.past_key_values,
+         hidden_states=outputs.hidden_states,
+         attentions=outputs.attentions,
+     )
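The patched forward above is installed through the corresponding monkey-patch entry point. A minimal sketch, assuming the no-argument form enables the fused linear cross entropy path (the model id is illustrative):

```python
import transformers

from liger_kernel.transformers import apply_liger_kernel_to_gemma

# With the fused linear cross entropy patch in place, GemmaForCausalLM.forward is
# replaced by the lce_forward above, so training with labels skips the full logits.
apply_liger_kernel_to_gemma()

model = transformers.AutoModelForCausalLM.from_pretrained("google/gemma-7b")
```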
liger_kernel/transformers/model/llama.py CHANGED
@@ -97,7 +97,7 @@ def lce_forward(
      loss = None
      logits = None

-     if self.training:
+     if self.training and (labels is not None):
          shift_hidden_states = hidden_states[..., :-1, :].contiguous()
          shift_labels = labels[..., 1:].contiguous()

liger_kernel/transformers/model/mistral.py ADDED
@@ -0,0 +1,138 @@
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+ from torch.nn import CrossEntropyLoss
+ from transformers.cache_utils import Cache
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+ from transformers.models.mistral.modeling_mistral import (
+     _CONFIG_FOR_DOC,
+     MISTRAL_INPUTS_DOCSTRING,
+ )
+ from transformers.utils import (
+     add_start_docstrings_to_model_forward,
+     replace_return_docstrings,
+ )
+
+ from liger_kernel.transformers.fused_linear_cross_entropy import (
+     LigerFusedLinearCrossEntropyLoss,
+ )
+
+
+ @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
+ @replace_return_docstrings(
+     output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+ )
+ def lce_forward(
+     self,
+     input_ids: torch.LongTensor = None,
+     attention_mask: Optional[torch.Tensor] = None,
+     position_ids: Optional[torch.LongTensor] = None,
+     past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+     inputs_embeds: Optional[torch.FloatTensor] = None,
+     labels: Optional[torch.LongTensor] = None,
+     use_cache: Optional[bool] = None,
+     output_attentions: Optional[bool] = None,
+     output_hidden_states: Optional[bool] = None,
+     return_dict: Optional[bool] = None,
+     cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+     r"""
+     Copy paste Mistral's forward but replace torch cross entropy with liger fused linear cross entropy
+
+
+     Args:
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+     Returns:
+
+     Example:
+
+     ```python
+     >>> from transformers import AutoTokenizer, MistralForCausalLM
+
+     >>> model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
+     >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
+
+     >>> prompt = "Hey, are you conscious? Can you talk to me?"
+     >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+     >>> # Generate
+     >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+     "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+     ```"""
+
+     output_attentions = (
+         output_attentions
+         if output_attentions is not None
+         else self.config.output_attentions
+     )
+     output_hidden_states = (
+         output_hidden_states
+         if output_hidden_states is not None
+         else self.config.output_hidden_states
+     )
+     return_dict = (
+         return_dict if return_dict is not None else self.config.use_return_dict
+     )
+
+     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+     outputs = self.model(
+         input_ids=input_ids,
+         attention_mask=attention_mask,
+         position_ids=position_ids,
+         past_key_values=past_key_values,
+         inputs_embeds=inputs_embeds,
+         use_cache=use_cache,
+         output_attentions=output_attentions,
+         output_hidden_states=output_hidden_states,
+         return_dict=return_dict,
+         cache_position=cache_position,
+     )
+
+     hidden_states = outputs[0]
+
+     loss = None
+     logits = None
+
+     if self.training and (labels is not None):
+         shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+         shift_labels = labels[..., 1:].contiguous()
+
+         # flatten tokens
+         shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+         shift_labels = shift_labels.view(-1)
+
+         lce = LigerFusedLinearCrossEntropyLoss()
+         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+
+     else:
+         logits = self.lm_head(hidden_states)
+         if labels is not None:
+             # Upcast to float if we need to compute the loss to avoid potential precision issues
+             logits = logits.float()
+             # Shift so that tokens < n predict n
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             # Flatten the tokens
+             shift_logits = shift_logits.view(-1, self.config.vocab_size)
+             shift_labels = shift_labels.view(-1)
+             # Ensure tensors are on the same device
+             shift_labels = shift_labels.to(shift_logits.device)
+             loss_fct = CrossEntropyLoss()
+             loss = loss_fct(shift_logits, shift_labels)
+
+     if not return_dict:
+         output = (logits,) + outputs[1:]
+         return (loss,) + output if loss is not None else output
+
+     return CausalLMOutputWithPast(
+         loss=loss,
+         logits=logits,
+         past_key_values=outputs.past_key_values,
+         hidden_states=outputs.hidden_states,
+         attentions=outputs.attentions,
+     )
liger_kernel/transformers/model/mixtral.py ADDED
@@ -0,0 +1,158 @@
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+ from torch.nn import CrossEntropyLoss
+ from transformers.modeling_outputs import MoeCausalLMOutputWithPast
+ from transformers.models.mixtral.modeling_mixtral import (
+     _CONFIG_FOR_DOC,
+     MIXTRAL_INPUTS_DOCSTRING,
+     load_balancing_loss_func,
+ )
+ from transformers.utils import (
+     add_start_docstrings_to_model_forward,
+     replace_return_docstrings,
+ )
+
+ from liger_kernel.transformers.fused_linear_cross_entropy import (
+     LigerFusedLinearCrossEntropyLoss,
+ )
+
+
+ @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
+ @replace_return_docstrings(
+     output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+ )
+ def lce_forward(
+     self,
+     input_ids: torch.LongTensor = None,
+     attention_mask: Optional[torch.Tensor] = None,
+     position_ids: Optional[torch.LongTensor] = None,
+     past_key_values: Optional[List[torch.FloatTensor]] = None,
+     inputs_embeds: Optional[torch.FloatTensor] = None,
+     labels: Optional[torch.LongTensor] = None,
+     use_cache: Optional[bool] = None,
+     output_attentions: Optional[bool] = None,
+     output_hidden_states: Optional[bool] = None,
+     output_router_logits: Optional[bool] = None,
+     return_dict: Optional[bool] = None,
+     cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
+     r"""
+     Copy paste Mixtral's forward from transformers v4.44.2 but replace torch cross entropy with liger fused linear cross entropy
+
+
+     Args:
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+     Returns:
+
+     Example:
+
+     ```python
+     >>> from transformers import AutoTokenizer, MixtralForCausalLM
+
+     >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
+     >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
+
+     >>> prompt = "Hey, are you conscious? Can you talk to me?"
+     >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+     >>> # Generate
+     >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+     "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+     ```"""
+
+     output_attentions = (
+         output_attentions
+         if output_attentions is not None
+         else self.config.output_attentions
+     )
+     output_router_logits = (
+         output_router_logits
+         if output_router_logits is not None
+         else self.config.output_router_logits
+     )
+
+     output_hidden_states = (
+         output_hidden_states
+         if output_hidden_states is not None
+         else self.config.output_hidden_states
+     )
+     return_dict = (
+         return_dict if return_dict is not None else self.config.use_return_dict
+     )
+
+     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+     outputs = self.model(
+         input_ids=input_ids,
+         attention_mask=attention_mask,
+         position_ids=position_ids,
+         past_key_values=past_key_values,
+         inputs_embeds=inputs_embeds,
+         use_cache=use_cache,
+         output_attentions=output_attentions,
+         output_hidden_states=output_hidden_states,
+         output_router_logits=output_router_logits,
+         return_dict=return_dict,
+         cache_position=cache_position,
+     )
+
+     hidden_states = outputs[0]
+     logits = self.lm_head(hidden_states)
+     logits = logits.float()
+
+     loss = None
+     if self.training and (labels is not None):
+         shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+         shift_labels = labels[..., 1:].contiguous()
+         # Flatten the tokens
+         shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+         shift_labels = shift_labels.view(-1)
+
+         lce = LigerFusedLinearCrossEntropyLoss()
+         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+     elif labels is not None:
+         # Shift so that tokens < n predict n
+         shift_logits = logits[..., :-1, :].contiguous()
+         shift_labels = labels[..., 1:].contiguous()
+         # Flatten the tokens
+         shift_logits = shift_logits.view(-1, self.config.vocab_size)
+         shift_labels = shift_labels.view(-1)
+         # Enable model parallelism
+         shift_labels = shift_labels.to(shift_logits.device)
+
+         loss_fct = CrossEntropyLoss()
+         loss = loss_fct(shift_logits, shift_labels)
+
+     aux_loss = None
+     if output_router_logits:
+         aux_loss = load_balancing_loss_func(
+             outputs.router_logits if return_dict else outputs[-1],
+             self.num_experts,
+             self.num_experts_per_tok,
+             attention_mask,
+         )
+         if labels is not None:
+             loss += self.router_aux_loss_coef * aux_loss.to(
+                 loss.device
+             )  # make sure to reside in the same device
+
+     if not return_dict:
+         output = (logits,) + outputs[1:]
+         if output_router_logits:
+             output = (aux_loss,) + output
+         return (loss,) + output if loss is not None else output
+
+     return MoeCausalLMOutputWithPast(
+         loss=loss,
+         aux_loss=aux_loss,
+         logits=logits,
+         past_key_values=outputs.past_key_values,
+         hidden_states=outputs.hidden_states,
+         attentions=outputs.attentions,
+         router_logits=outputs.router_logits,
+     )
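To make the loss combination above concrete, a small standalone sketch of how the auxiliary load-balancing term is folded into the token loss; all values are illustrative stand-ins for what `lce_forward` computes:

```python
import torch

# Stand-ins for the values computed in lce_forward above.
token_loss = torch.tensor(2.31)  # fused linear cross entropy over shifted tokens
aux_loss = torch.tensor(0.05)    # load_balancing_loss_func over the router logits
router_aux_loss_coef = 0.02      # taken from the Mixtral config in the real code

# Only added when labels are present, matching the guard in the patched forward.
total_loss = token_loss + router_aux_loss_coef * aux_loss.to(token_loss.device)
```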