liger-kernel 0.1.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. liger_kernel/env_report.py +46 -0
  2. liger_kernel/ops/cross_entropy.py +130 -63
  3. liger_kernel/ops/experimental/embedding.py +143 -0
  4. liger_kernel/ops/fused_linear_cross_entropy.py +203 -126
  5. liger_kernel/ops/geglu.py +56 -44
  6. liger_kernel/ops/kl_div.py +258 -0
  7. liger_kernel/ops/layer_norm.py +236 -0
  8. liger_kernel/ops/rms_norm.py +220 -84
  9. liger_kernel/ops/rope.py +91 -84
  10. liger_kernel/ops/swiglu.py +50 -43
  11. liger_kernel/ops/utils.py +12 -0
  12. liger_kernel/transformers/__init__.py +22 -0
  13. liger_kernel/transformers/auto_model.py +45 -0
  14. liger_kernel/transformers/cross_entropy.py +11 -1
  15. liger_kernel/transformers/experimental/embedding.py +28 -0
  16. liger_kernel/transformers/functional.py +19 -0
  17. liger_kernel/transformers/fused_linear_cross_entropy.py +8 -2
  18. liger_kernel/transformers/geglu.py +4 -2
  19. liger_kernel/transformers/kl_div.py +14 -0
  20. liger_kernel/transformers/layer_norm.py +30 -0
  21. liger_kernel/transformers/model/gemma.py +138 -0
  22. liger_kernel/transformers/model/llama.py +1 -1
  23. liger_kernel/transformers/model/mistral.py +138 -0
  24. liger_kernel/transformers/model/mixtral.py +158 -0
  25. liger_kernel/transformers/model/phi3.py +136 -0
  26. liger_kernel/transformers/model/qwen2.py +135 -0
  27. liger_kernel/transformers/model/qwen2_vl.py +172 -0
  28. liger_kernel/transformers/monkey_patch.py +579 -14
  29. liger_kernel/transformers/rms_norm.py +23 -4
  30. liger_kernel/transformers/swiglu.py +24 -0
  31. liger_kernel/transformers/trainer_integration.py +2 -45
  32. liger_kernel-0.3.1.dist-info/METADATA +395 -0
  33. liger_kernel-0.3.1.dist-info/RECORD +42 -0
  34. {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.1.dist-info}/WHEEL +1 -1
  35. liger_kernel-0.1.0.dist-info/METADATA +0 -16
  36. liger_kernel-0.1.0.dist-info/RECORD +0 -27
  37. {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.1.dist-info}/LICENSE +0 -0
  38. {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.1.dist-info}/NOTICE +0 -0
  39. {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.1.dist-info}/top_level.txt +0 -0
liger_kernel/ops/swiglu.py CHANGED
@@ -14,7 +14,7 @@ def silu(x):
  def _swiglu_forward_kernel(
      a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
  ):
-     program_id = tl.program_id(0)
+     program_id = tl.program_id(0).cast(tl.int64)

      # locate start index
      a_ptr += program_id * stride
@@ -35,7 +35,7 @@ def _swiglu_forward_kernel(
  def _swiglu_backward_kernel(
      dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
  ):
-     program_id = tl.program_id(0)
+     program_id = tl.program_id(0).cast(tl.int64)

      # locate start index
      dc_ptr += program_id * stride
@@ -60,54 +60,61 @@ def _swiglu_backward_kernel(
      tl.store(b_ptr + col_offsets, db_row, mask=mask)


+ def swiglu_forward(a, b):
+     ori_shape = a.shape
+
+     n_cols = ori_shape[-1]
+     a = a.view(-1, n_cols)
+     b = b.view(-1, n_cols)
+     c = torch.empty_like(a)
+     n_rows = a.shape[0]
+
+     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+     _swiglu_forward_kernel[(n_rows,)](
+         a,
+         b,
+         c,
+         c.stride(-2),
+         n_cols=n_cols,
+         BLOCK_SIZE=BLOCK_SIZE,
+         num_warps=num_warps,
+     )
+     return a, b, c.view(*ori_shape)
+
+
+ def swiglu_backward(a, b, dc):
+
+     ori_shape = dc.shape
+     n_cols = ori_shape[-1]
+     dc = dc.view(-1, n_cols)
+     n_rows = dc.shape[0]
+
+     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+     _swiglu_backward_kernel[(n_rows,)](
+         dc,
+         a,
+         b,
+         dc.stride(-2),
+         n_cols=n_cols,
+         BLOCK_SIZE=BLOCK_SIZE,
+         num_warps=num_warps,
+     )
+     return a.view(*ori_shape), b.view(*ori_shape)
+
+
  class LigerSiLUMulFunction(torch.autograd.Function):
      @staticmethod
      @ensure_contiguous
      def forward(ctx, a, b):
-         ori_shape = a.shape
-
-         n_cols = ori_shape[-1]
-         a = a.view(-1, n_cols)
-         b = b.view(-1, n_cols)
-         c = torch.zeros_like(a)
-         n_rows = a.shape[0]
-
-         BLOCK_SIZE, num_warps = calculate_settings(n_cols)
-
-         _swiglu_forward_kernel[(n_rows,)](
-             a,
-             b,
-             c,
-             c.stride(-2),
-             n_cols=n_cols,
-             BLOCK_SIZE=BLOCK_SIZE,
-             num_warps=num_warps,
-         )
-
+         a, b, c = swiglu_forward(a, b)
          ctx.save_for_backward(a, b)
-
-         return c.view(*ori_shape)
+         return c

      @staticmethod
      @ensure_contiguous
      def backward(ctx, dc):
-
-         ori_shape = dc.shape
-         n_cols = ori_shape[-1]
-         dc = dc.view(-1, n_cols)
          a, b = ctx.saved_tensors
-         n_rows = dc.shape[0]
-
-         BLOCK_SIZE, num_warps = calculate_settings(n_cols)
-
-         _swiglu_backward_kernel[(n_rows,)](
-             dc,
-             a,
-             b,
-             dc.stride(-2),
-             n_cols=n_cols,
-             BLOCK_SIZE=BLOCK_SIZE,
-             num_warps=num_warps,
-         )
-
-         return a.view(*ori_shape), b.view(*ori_shape)
+         a, b = swiglu_backward(a, b, dc)
+         return a, b
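The refactor above extracts the kernel launches into standalone swiglu_forward/swiglu_backward helpers and switches the output buffer from torch.zeros_like to torch.empty_like. A minimal usage sketch of the autograd entry point, assuming a CUDA device with Triton available; shapes are illustrative only:

```python
import torch
from liger_kernel.ops.swiglu import LigerSiLUMulFunction

a = torch.randn(2, 128, 4096, device="cuda", requires_grad=True)  # gate projection output
b = torch.randn(2, 128, 4096, device="cuda", requires_grad=True)  # up projection output
c = LigerSiLUMulFunction.apply(a, b)  # fused silu(a) * b via the Triton kernels above
c.sum().backward()                    # gradients come from the fused backward kernel
```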
liger_kernel/ops/utils.py CHANGED
@@ -1,3 +1,15 @@
+ """
+ This file incorporates code from Unsloth licensed under the Apache License, Version 2.0.
+ See the original Unsloth repository at https://github.com/unslothai/unsloth.
+
+ The following line
+ https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/utils.py#L23
+ is based on code from Unsloth, located at:
+ https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43
+
+ Modifications made by Yanning Chen, 2024.
+ """
+
  import functools
  import importlib
  from typing import Callable
liger_kernel/transformers/__init__.py CHANGED
@@ -1,6 +1,28 @@
+ from liger_kernel.transformers.auto_model import ( # noqa: F401
+     AutoLigerKernelForCausalLM,
+ )
+ from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss # noqa: F401
+ from liger_kernel.transformers.fused_linear_cross_entropy import ( # noqa: F401
+     LigerFusedLinearCrossEntropyLoss,
+ )
+ from liger_kernel.transformers.geglu import LigerGEGLUMLP # noqa: F401
+ from liger_kernel.transformers.layer_norm import LigerLayerNorm # noqa: F401
  from liger_kernel.transformers.monkey_patch import ( # noqa: F401
+     _apply_liger_kernel,
+     _apply_liger_kernel_to_instance,
      apply_liger_kernel_to_gemma,
+     apply_liger_kernel_to_gemma2,
      apply_liger_kernel_to_llama,
      apply_liger_kernel_to_mistral,
      apply_liger_kernel_to_mixtral,
+     apply_liger_kernel_to_phi3,
+     apply_liger_kernel_to_qwen2,
+     apply_liger_kernel_to_qwen2_vl,
+ )
+ from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401
+ from liger_kernel.transformers.rope import liger_rotary_pos_emb # noqa: F401
+ from liger_kernel.transformers.swiglu import ( # noqa: F401
+     LigerBlockSparseTop2MLP,
+     LigerPhi3SwiGLUMLP,
+     LigerSwiGLUMLP,
  )
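The 0.3.1 top-level namespace re-exports the modules and patch helpers listed above. A hedged usage sketch; the checkpoint name is illustrative, and the patch functions may accept per-kernel keyword flags that are not shown in this diff:

```python
from transformers import AutoModelForCausalLM
from liger_kernel.transformers import apply_liger_kernel_to_llama

apply_liger_kernel_to_llama()  # monkey-patches transformers' Llama modeling code with Liger kernels
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")  # illustrative checkpoint
```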
liger_kernel/transformers/auto_model.py ADDED
@@ -0,0 +1,45 @@
+ import inspect
+
+ from transformers import AutoConfig, AutoModelForCausalLM
+
+ from liger_kernel.transformers.monkey_patch import (
+     MODEL_TYPE_TO_APPLY_LIGER_FN,
+     _apply_liger_kernel,
+ )
+
+
+ def _get_model_config(model_dir, **model_init_kwargs):
+     config = AutoConfig.from_pretrained(model_dir, **model_init_kwargs)
+     return config
+
+
+ class AutoLigerKernelForCausalLM(AutoModelForCausalLM):
+     """
+     This class is a drop-in replacement for AutoModelForCausalLM that applies the Liger Kernel to the model
+     if applicable.
+     """
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+         model_config = _get_model_config(pretrained_model_name_or_path, **kwargs)
+
+         # Determine the model type and apply the Liger Kernel if applicable
+         # Note: _apply_liger_kernel will only pass relevant kwargs to the apply_liger_kernel_to_* function
+         model_type = model_config.model_type
+
+         _apply_liger_kernel(model_type, **kwargs)
+
+         # Filter out kwargs that were passed to the apply_liger_* function, which will cause
+         # model initialization errors otherwise
+         apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
+         apply_fn_signature = inspect.signature(apply_fn)
+
+         applicable_kwargs = {
+             key: value
+             for key, value in kwargs.items()
+             if key not in apply_fn_signature.parameters
+         }
+
+         return super().from_pretrained(
+             pretrained_model_name_or_path, *model_args, **applicable_kwargs
+         )
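A usage sketch for the new auto class; the checkpoint and extra kwargs are illustrative. Per the code above, kwargs matching the relevant apply_liger_kernel_to_* signature are consumed by the patch step, and the remainder is forwarded to from_pretrained:

```python
import torch
from liger_kernel.transformers import AutoLigerKernelForCausalLM

# Requires a model_type registered in MODEL_TYPE_TO_APPLY_LIGER_FN.
model = AutoLigerKernelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",  # illustrative checkpoint
    torch_dtype=torch.bfloat16,    # passed through to AutoModelForCausalLM.from_pretrained
)
```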
liger_kernel/transformers/cross_entropy.py CHANGED
@@ -6,6 +6,16 @@ from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
  class LigerCrossEntropyLoss(CrossEntropyLoss):
      def __init__(self, *args, **kwargs):
          super(LigerCrossEntropyLoss, self).__init__(*args, **kwargs)
+         assert (self.label_smoothing >= 0) and (
+             self.label_smoothing <= 1
+         ), f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}"
+         assert self.reduction in {
+             "mean",
+             "sum",
+             "none",
+         }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {self.reduction}"

      def forward(self, _input, target):
-         return LigerCrossEntropyFunction.apply(_input, target, self.ignore_index)
+         return LigerCrossEntropyFunction.apply(
+             _input, target, self.ignore_index, self.label_smoothing, self.reduction
+         )
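With label_smoothing and reduction wired through, the loss is configured like its torch.nn counterpart. A short sketch; vocabulary size and shapes are illustrative and a CUDA device is assumed:

```python
import torch
from liger_kernel.transformers import LigerCrossEntropyLoss

loss_fn = LigerCrossEntropyLoss(ignore_index=-100, label_smoothing=0.1, reduction="mean")
logits = torch.randn(8, 32000, device="cuda", requires_grad=True)  # (tokens, vocab)
target = torch.randint(0, 32000, (8,), device="cuda")
loss = loss_fn(logits, target)
loss.backward()
```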
liger_kernel/transformers/experimental/embedding.py ADDED
@@ -0,0 +1,28 @@
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+
+ from liger_kernel.ops.experimental.embedding import LigerEmbeddingFunction
+
+
+ class LigerEmbedding(nn.Module):
+     def __init__(
+         self, num_embeddings, embedding_dim, padding_idx: Optional[int] = None
+     ):
+         super().__init__()
+         self.num_embeddings = num_embeddings
+         self.embedding_dim = embedding_dim
+         self.padding_idx = padding_idx
+         self.weight = nn.Parameter(torch.randn(num_embeddings, embedding_dim))
+
+         if padding_idx is not None:
+             with torch.no_grad():
+                 self.weight[padding_idx].fill_(0)
+
+     def forward(self, indices):
+         embedded = LigerEmbeddingFunction.apply(self.weight, indices)
+         if self.padding_idx is not None:
+             embedded = embedded.clone()
+             embedded[indices == self.padding_idx] = 0
+         return embedded
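A short sketch of the experimental embedding; sizes are arbitrary and a CUDA device is assumed. As in the forward above, rows at padding_idx are zeroed after the kernel lookup:

```python
import torch
from liger_kernel.transformers.experimental.embedding import LigerEmbedding

emb = LigerEmbedding(num_embeddings=1000, embedding_dim=64, padding_idx=0).cuda()
ids = torch.tensor([[1, 2, 0, 7]], device="cuda")
out = emb(ids)  # rows where ids == padding_idx come back as zeros
```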
liger_kernel/transformers/functional.py ADDED
@@ -0,0 +1,19 @@
+ from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
+ from liger_kernel.ops.fused_linear_cross_entropy import (
+     LigerFusedLinearCrossEntropyFunction,
+ )
+ from liger_kernel.ops.geglu import LigerGELUMulFunction
+ from liger_kernel.ops.kl_div import LigerKLDivLossFunction
+ from liger_kernel.ops.layer_norm import LigerLayerNormFunction
+ from liger_kernel.ops.rms_norm import LigerRMSNormFunction
+ from liger_kernel.ops.rope import LigerRopeFunction
+ from liger_kernel.ops.swiglu import LigerSiLUMulFunction
+
+ liger_swiglu = LigerSiLUMulFunction.apply
+ liger_cross_entropy = LigerCrossEntropyFunction.apply
+ liger_fused_linear_cross_entropy = LigerFusedLinearCrossEntropyFunction.apply
+ liger_geglu = LigerGELUMulFunction.apply
+ liger_rms_norm = LigerRMSNormFunction.apply
+ liger_rope = LigerRopeFunction.apply
+ liger_layer_norm = LigerLayerNormFunction.apply
+ liger_kl_div = LigerKLDivLossFunction.apply
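These aliases expose the raw autograd Functions for callers that do not want the nn.Module wrappers. For example, liger_cross_entropy takes the same positional arguments that LigerCrossEntropyLoss.forward passes through above (input, target, ignore_index, label_smoothing, reduction); tensors in this sketch are illustrative:

```python
import torch
from liger_kernel.transformers.functional import liger_cross_entropy

logits = torch.randn(8, 32000, device="cuda", requires_grad=True)
target = torch.randint(0, 32000, (8,), device="cuda")
loss = liger_cross_entropy(logits, target, -100, 0.0, "mean")  # ignore_index, label_smoothing, reduction
```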
liger_kernel/transformers/fused_linear_cross_entropy.py CHANGED
@@ -9,7 +9,13 @@ class LigerFusedLinearCrossEntropyLoss(CrossEntropyLoss):
      def __init__(self, *args, **kwargs):
          super(LigerFusedLinearCrossEntropyLoss, self).__init__(*args, **kwargs)

-     def forward(self, lin_weight, _input, target):
+     def forward(self, lin_weight, _input, target, bias=None):
          return LigerFusedLinearCrossEntropyFunction.apply(
-             _input, lin_weight, target, self.ignore_index
+             _input,
+             lin_weight,
+             target,
+             bias,
+             self.ignore_index,
+             self.label_smoothing,
+             self.reduction,
          )
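Matching the updated signature, the fused loss takes the lm_head weight and unprojected hidden states rather than materialized logits, which is exactly how the model patches later in this diff call it. Dimensions in this sketch are illustrative:

```python
import torch
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

hidden = torch.randn(16, 4096, device="cuda", requires_grad=True)             # (tokens, hidden_size)
lm_head_weight = torch.randn(32000, 4096, device="cuda", requires_grad=True)  # (vocab_size, hidden_size)
labels = torch.randint(0, 32000, (16,), device="cuda")

lce = LigerFusedLinearCrossEntropyLoss()
loss = lce(lm_head_weight, hidden, labels)  # optional bias defaults to None
```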
liger_kernel/transformers/geglu.py CHANGED
@@ -13,8 +13,10 @@ class LigerGEGLUMLP(nn.Module):
          self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
          self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
          # TODO: support exact GELU
-         if config.hidden_act not in ["gelu_pytorch_tanh"]:
-             raise ValueError(f"Activation function {config.hidden_act} not supported.")
+         # Right now Gemma 1, 1.1 and 2 models are all using `gelu_pytorch_tanh`
+         # https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/gemma/modeling_gemma.py#L175
+         # https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/activations.py#L46
+         # So we can safely assume we use tanh approximation form all the time

      def forward(self, x):

liger_kernel/transformers/kl_div.py ADDED
@@ -0,0 +1,14 @@
+ import torch.nn as nn
+
+ from liger_kernel.ops.kl_div import LigerKLDivLossFunction
+
+
+ class LigerKLDIVLoss(nn.KLDivLoss):
+     def __init__(self, eps: float = 1e-10, *args, **kwargs):
+         super(LigerKLDIVLoss, self).__init__(*args, **kwargs)
+         self.eps = eps
+
+     def forward(self, y_pred, y_true):
+         return LigerKLDivLossFunction.apply(
+             y_pred, y_true, self.reduction, self.log_target, self.eps
+         )
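A sketch of the KL divergence wrapper; like nn.KLDivLoss, the prediction is expected in log space, and the inherited log_target flag controls whether the target is too. Shapes and the device below are illustrative:

```python
import torch
from liger_kernel.transformers.kl_div import LigerKLDIVLoss

kl = LigerKLDIVLoss()  # reduction and log_target defaults come from nn.KLDivLoss
y_pred = torch.randn(4, 128, device="cuda").log_softmax(dim=-1)
y_true = torch.randn(4, 128, device="cuda").softmax(dim=-1)
loss = kl(y_pred, y_true)
```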
liger_kernel/transformers/layer_norm.py ADDED
@@ -0,0 +1,30 @@
+ import torch
+ import torch.nn as nn
+
+ from liger_kernel.ops.layer_norm import LigerLayerNormFunction
+
+
+ class LigerLayerNorm(nn.Module):
+     def __init__(self, hidden_size, eps=1e-6, bias=False, init_fn="ones"):
+         super().__init__()
+         assert init_fn in [
+             "ones",
+             "zeros",
+         ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
+         self.hidden_size = hidden_size
+         self.eps = eps
+         self.weight = nn.Parameter(
+             torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)
+         )
+         self.bias = nn.Parameter(
+             torch.randn(hidden_size) if bias else torch.zeros(hidden_size)
+         )
+         self.variance_epsilon = eps
+
+     def forward(self, hidden_states):
+         return LigerLayerNormFunction.apply(
+             hidden_states, self.weight, self.bias, self.variance_epsilon
+         )
+
+     def extra_repr(self):
+         return f"{self.hidden_size}, eps={self.eps}"
liger_kernel/transformers/model/gemma.py ADDED
@@ -0,0 +1,138 @@
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+ from torch.nn import CrossEntropyLoss
+ from transformers.cache_utils import Cache
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+ from transformers.models.gemma.modeling_gemma import (
+     _CONFIG_FOR_DOC,
+     GEMMA_INPUTS_DOCSTRING,
+ )
+ from transformers.utils import (
+     add_start_docstrings_to_model_forward,
+     replace_return_docstrings,
+ )
+
+ from liger_kernel.transformers.fused_linear_cross_entropy import (
+     LigerFusedLinearCrossEntropyLoss,
+ )
+
+
+ @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
+ @replace_return_docstrings(
+     output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+ )
+ def lce_forward(
+     self,
+     input_ids: torch.LongTensor = None,
+     attention_mask: Optional[torch.Tensor] = None,
+     position_ids: Optional[torch.LongTensor] = None,
+     past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+     inputs_embeds: Optional[torch.FloatTensor] = None,
+     labels: Optional[torch.LongTensor] = None,
+     use_cache: Optional[bool] = None,
+     output_attentions: Optional[bool] = None,
+     output_hidden_states: Optional[bool] = None,
+     return_dict: Optional[bool] = None,
+     cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+     r"""
+
+     copy paste transformers.models.gemma.modeling_gemma causalLM with loss replaced with liger fused cross entropy
+
+     Args:
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+     Returns:
+
+     Example:
+
+     ```python
+     >>> from transformers import AutoTokenizer, GemmaForCausalLM
+
+     >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
+     >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
+
+     >>> prompt = "What is your favorite condiment?"
+     >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+     >>> # Generate
+     >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+     "What is your favorite condiment?"
+     ```"""
+     output_attentions = (
+         output_attentions
+         if output_attentions is not None
+         else self.config.output_attentions
+     )
+     output_hidden_states = (
+         output_hidden_states
+         if output_hidden_states is not None
+         else self.config.output_hidden_states
+     )
+     return_dict = (
+         return_dict if return_dict is not None else self.config.use_return_dict
+     )
+
+     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+     outputs = self.model(
+         input_ids=input_ids,
+         attention_mask=attention_mask,
+         position_ids=position_ids,
+         past_key_values=past_key_values,
+         inputs_embeds=inputs_embeds,
+         use_cache=use_cache,
+         output_attentions=output_attentions,
+         output_hidden_states=output_hidden_states,
+         return_dict=return_dict,
+         cache_position=cache_position,
+     )
+
+     hidden_states = outputs[0]
+
+     loss = None
+     logits = None
+
+     if self.training and (labels is not None):
+         shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+         shift_labels = labels[..., 1:].contiguous()
+
+         # flatten
+
+         shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+         shift_labels = shift_labels.view(-1)
+
+         lce = LigerFusedLinearCrossEntropyLoss()
+         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+
+     else:
+         logits = self.lm_head(hidden_states)
+         if labels is not None:
+             # Upcast to float if we need to compute the loss to avoid potential precision issues
+             logits = logits.float()
+             # Shift so that tokens < n predict n
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             # Flatten the tokens
+             shift_logits = shift_logits.view(-1, self.config.vocab_size)
+             shift_labels = shift_labels.view(-1)
+             # Ensure tensors are on the same device
+             shift_labels = shift_labels.to(shift_logits.device)
+             loss_fct = CrossEntropyLoss()
+             loss = loss_fct(shift_logits, shift_labels)
+
+     if not return_dict:
+         output = (logits,) + outputs[1:]
+         return (loss,) + output if loss is not None else output
+
+     return CausalLMOutputWithPast(
+         loss=loss,
+         logits=logits,
+         past_key_values=outputs.past_key_values,
+         hidden_states=outputs.hidden_states,
+         attentions=outputs.attentions,
+     )
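This lce_forward is not meant to be called directly; it is presumably swapped in through the monkey-patch entry points re-exported earlier (monkey_patch.py appears in the file list but is not shown in this diff). A hedged sketch using default arguments; the exact keyword options of apply_liger_kernel_to_gemma are defined in monkey_patch.py:

```python
from transformers import AutoModelForCausalLM
from liger_kernel.transformers import apply_liger_kernel_to_gemma

apply_liger_kernel_to_gemma()  # patches Gemma's causal-LM forward (among other ops) with the Liger versions
model = AutoModelForCausalLM.from_pretrained("google/gemma-7b")  # checkpoint from the docstring example
```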
liger_kernel/transformers/model/llama.py CHANGED
@@ -97,7 +97,7 @@ def lce_forward(
      loss = None
      logits = None

-     if self.training:
+     if self.training and (labels is not None):
          shift_hidden_states = hidden_states[..., :-1, :].contiguous()
          shift_labels = labels[..., 1:].contiguous()

liger_kernel/transformers/model/mistral.py ADDED
@@ -0,0 +1,138 @@
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+ from torch.nn import CrossEntropyLoss
+ from transformers.cache_utils import Cache
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+ from transformers.models.mistral.modeling_mistral import (
+     _CONFIG_FOR_DOC,
+     MISTRAL_INPUTS_DOCSTRING,
+ )
+ from transformers.utils import (
+     add_start_docstrings_to_model_forward,
+     replace_return_docstrings,
+ )
+
+ from liger_kernel.transformers.fused_linear_cross_entropy import (
+     LigerFusedLinearCrossEntropyLoss,
+ )
+
+
+ @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
+ @replace_return_docstrings(
+     output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+ )
+ def lce_forward(
+     self,
+     input_ids: torch.LongTensor = None,
+     attention_mask: Optional[torch.Tensor] = None,
+     position_ids: Optional[torch.LongTensor] = None,
+     past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+     inputs_embeds: Optional[torch.FloatTensor] = None,
+     labels: Optional[torch.LongTensor] = None,
+     use_cache: Optional[bool] = None,
+     output_attentions: Optional[bool] = None,
+     output_hidden_states: Optional[bool] = None,
+     return_dict: Optional[bool] = None,
+     cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+     r"""
+     Copy paste Mistral's forward but replace torch cross entropy with liger fused linear cross entropy
+
+
+     Args:
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+     Returns:
+
+     Example:
+
+     ```python
+     >>> from transformers import AutoTokenizer, MistralForCausalLM
+
+     >>> model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
+     >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
+
+     >>> prompt = "Hey, are you conscious? Can you talk to me?"
+     >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+     >>> # Generate
+     >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+     "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+     ```"""
+
+     output_attentions = (
+         output_attentions
+         if output_attentions is not None
+         else self.config.output_attentions
+     )
+     output_hidden_states = (
+         output_hidden_states
+         if output_hidden_states is not None
+         else self.config.output_hidden_states
+     )
+     return_dict = (
+         return_dict if return_dict is not None else self.config.use_return_dict
+     )
+
+     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+     outputs = self.model(
+         input_ids=input_ids,
+         attention_mask=attention_mask,
+         position_ids=position_ids,
+         past_key_values=past_key_values,
+         inputs_embeds=inputs_embeds,
+         use_cache=use_cache,
+         output_attentions=output_attentions,
+         output_hidden_states=output_hidden_states,
+         return_dict=return_dict,
+         cache_position=cache_position,
+     )
+
+     hidden_states = outputs[0]
+
+     loss = None
+     logits = None
+
+     if self.training and (labels is not None):
+         shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+         shift_labels = labels[..., 1:].contiguous()
+
+         # flatten tokens
+         shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+         shift_labels = shift_labels.view(-1)
+
+         lce = LigerFusedLinearCrossEntropyLoss()
+         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+
+     else:
+         logits = self.lm_head(hidden_states)
+         if labels is not None:
+             # Upcast to float if we need to compute the loss to avoid potential precision issues
+             logits = logits.float()
+             # Shift so that tokens < n predict n
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             # Flatten the tokens
+             shift_logits = shift_logits.view(-1, self.config.vocab_size)
+             shift_labels = shift_labels.view(-1)
+             # Ensure tensors are on the same device
+             shift_labels = shift_labels.to(shift_logits.device)
+             loss_fct = CrossEntropyLoss()
+             loss = loss_fct(shift_logits, shift_labels)
+
+     if not return_dict:
+         output = (logits,) + outputs[1:]
+         return (loss,) + output if loss is not None else output
+
+     return CausalLMOutputWithPast(
+         loss=loss,
+         logits=logits,
+         past_key_values=outputs.past_key_values,
+         hidden_states=outputs.hidden_states,
+         attentions=outputs.attentions,
+     )