liger-kernel 0.3.1__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. liger_kernel/ops/cross_entropy.py +5 -39
  2. liger_kernel/ops/experimental/mm_int8int2.py +355 -0
  3. liger_kernel/ops/fused_linear_cross_entropy.py +12 -9
  4. liger_kernel/ops/fused_linear_jsd.py +245 -0
  5. liger_kernel/ops/geglu.py +2 -2
  6. liger_kernel/ops/jsd.py +176 -0
  7. liger_kernel/ops/kl_div.py +2 -2
  8. liger_kernel/ops/rms_norm.py +67 -42
  9. liger_kernel/ops/swiglu.py +2 -2
  10. liger_kernel/ops/utils.py +62 -1
  11. liger_kernel/transformers/__init__.py +3 -0
  12. liger_kernel/transformers/functional.py +4 -0
  13. liger_kernel/transformers/fused_linear_jsd.py +98 -0
  14. liger_kernel/transformers/jsd.py +75 -0
  15. liger_kernel/transformers/model/gemma.py +124 -1
  16. liger_kernel/transformers/model/llama.py +135 -4
  17. liger_kernel/transformers/model/mistral.py +3 -0
  18. liger_kernel/transformers/model/mixtral.py +153 -2
  19. liger_kernel/transformers/model/mllama.py +274 -0
  20. liger_kernel/transformers/model/phi3.py +140 -2
  21. liger_kernel/transformers/model/qwen2.py +123 -2
  22. liger_kernel/transformers/model/qwen2_vl.py +8 -1
  23. liger_kernel/transformers/monkey_patch.py +158 -7
  24. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.0.dist-info}/METADATA +60 -28
  25. liger_kernel-0.4.0.dist-info/NOTICE +58 -0
  26. liger_kernel-0.4.0.dist-info/RECORD +48 -0
  27. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.0.dist-info}/WHEEL +1 -1
  28. liger_kernel-0.3.1.dist-info/NOTICE +0 -4
  29. liger_kernel-0.3.1.dist-info/RECORD +0 -42
  30. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.0.dist-info}/LICENSE +0 -0
  31. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.0.dist-info}/top_level.txt +0 -0
liger_kernel/ops/utils.py CHANGED
@@ -12,13 +12,19 @@ Modifications made by Yanning Chen, 2024.
 
 import functools
 import importlib
+import operator
 from typing import Callable
 
 import torch
 import triton
+import triton.language as tl
 from packaging.version import Version
 
 
+def is_hip() -> bool:
+    return torch.version.hip is not None
+
+
 def ensure_contiguous(fn):
     @functools.wraps(fn)
     def wrapper(ctx, *args, **kwargs):
@@ -45,7 +51,7 @@ def calculate_settings(n):
 
     num_warps = 4
     if BLOCK_SIZE >= 32768:
-        num_warps = 32
+        num_warps = 32 if not is_hip() else 16
     elif BLOCK_SIZE >= 8192:
         num_warps = 16
     elif BLOCK_SIZE >= 2048:
@@ -60,3 +66,58 @@ def compare_version(package: str, operator: Callable, target: str):
         return False
     pkg_version = Version(pkg.__version__)
     return operator(pkg_version, Version(target))
+
+
+def get_amp_custom_fwd_bwd() -> Callable:
+    if compare_version("torch", operator.ge, "2.4.0"):
+        return (
+            functools.partial(torch.amp.custom_fwd, device_type="cuda"),
+            functools.partial(torch.amp.custom_bwd, device_type="cuda"),
+        )
+    return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd
+
+
+amp_custom_fwd, amp_custom_bwd = get_amp_custom_fwd_bwd()
+
+
+torch_to_triton_dtype = {
+    torch.float32: tl.float32,
+    torch.float16: tl.float16,
+    torch.bfloat16: tl.bfloat16,
+}
+
+
+@triton.jit
+def element_mul_kernel(
+    X_ptr,
+    X_stride,
+    grad_output_ptr,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr.
+    The multiplication is performed in-place on the tensor pointed by X_ptr.
+
+    Parameters:
+    X_ptr: Pointer to the input tensor.
+    X_stride (int): The stride of the input tensor.
+    grad_output_ptr: Pointer to the gradient output value.
+    n_cols (int): The number of columns in the input tensor.
+    BLOCK_SIZE (int): The block size for Triton operations.
+    """
+
+    # Get the program ID and convert it to int64 to avoid overflow
+    program_id = tl.program_id(0).to(tl.int64)
+
+    # Locate the start index
+    X_ptr += program_id * X_stride
+
+    # Load the gradient output value
+    grad_output = tl.load(grad_output_ptr)
+
+    # Perform the element-wise multiplication
+    for i in range(0, n_cols, BLOCK_SIZE):
+        X_offsets = i + tl.arange(0, BLOCK_SIZE)
+        X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
+        tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
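
The new `element_mul_kernel` is a row-parallel, in-place scale (one Triton program per row). As a rough illustration only, here is a minimal, hypothetical launch sketch that is not part of the diff; it assumes `calculate_settings(n_cols)` returns a `(BLOCK_SIZE, num_warps)` pair and that a CUDA/ROCm device is available for Triton:

```python
# Hypothetical helper, for illustration only: scale every row of X in place by the
# scalar stored in grad_output, launching one Triton program per row.
import torch

from liger_kernel.ops.utils import calculate_settings, element_mul_kernel


def scale_rows_inplace(X: torch.Tensor, grad_output: torch.Tensor) -> torch.Tensor:
    n_rows, n_cols = X.shape
    # Assumption: calculate_settings(n_cols) returns (BLOCK_SIZE, num_warps).
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    element_mul_kernel[(n_rows,)](
        X,                # X_ptr
        X.stride(-2),     # X_stride: elements between consecutive rows
        grad_output,      # grad_output_ptr: 1-element tensor holding the scalar
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    return X


# X = torch.randn(4, 128, device="cuda")
# scale_rows_inplace(X, torch.tensor(2.0, device="cuda"))
```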
liger_kernel/transformers/__init__.py CHANGED
@@ -5,7 +5,9 @@ from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss  # noqa: F401
 from liger_kernel.transformers.fused_linear_cross_entropy import (  # noqa: F401
     LigerFusedLinearCrossEntropyLoss,
 )
+from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD  # noqa: F401
 from liger_kernel.transformers.geglu import LigerGEGLUMLP  # noqa: F401
+from liger_kernel.transformers.jsd import LigerJSD  # noqa: F401
 from liger_kernel.transformers.layer_norm import LigerLayerNorm  # noqa: F401
 from liger_kernel.transformers.monkey_patch import (  # noqa: F401
     _apply_liger_kernel,
@@ -15,6 +17,7 @@ from liger_kernel.transformers.monkey_patch import (  # noqa: F401
     apply_liger_kernel_to_llama,
     apply_liger_kernel_to_mistral,
     apply_liger_kernel_to_mixtral,
+    apply_liger_kernel_to_mllama,
     apply_liger_kernel_to_phi3,
     apply_liger_kernel_to_qwen2,
     apply_liger_kernel_to_qwen2_vl,
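
For context, a minimal sketch of how the newly exported symbols are meant to be used; the constructor arguments mirror the class definitions added later in this diff, and the patch call relies on its defaults:

```python
# Minimal sketch, assuming defaults; see the class definitions added later in this diff.
from liger_kernel.transformers import (
    LigerFusedLinearJSD,
    LigerJSD,
    apply_liger_kernel_to_mllama,
)

# The distillation losses are ordinary nn.Modules.
jsd = LigerJSD(beta=0.5, ignore_index=-100)
fused_jsd = LigerFusedLinearJSD(jsd_beta=0.5, ignore_index=-100, temperature=1.0)

# Patch Hugging Face's Mllama modeling code in place with Liger kernels;
# typically called before the model is instantiated.
apply_liger_kernel_to_mllama()
```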
liger_kernel/transformers/functional.py CHANGED
@@ -2,7 +2,9 @@ from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
 from liger_kernel.ops.fused_linear_cross_entropy import (
     LigerFusedLinearCrossEntropyFunction,
 )
+from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
 from liger_kernel.ops.geglu import LigerGELUMulFunction
+from liger_kernel.ops.jsd import LigerJSDFunction
 from liger_kernel.ops.kl_div import LigerKLDivLossFunction
 from liger_kernel.ops.layer_norm import LigerLayerNormFunction
 from liger_kernel.ops.rms_norm import LigerRMSNormFunction
@@ -17,3 +19,5 @@ liger_rms_norm = LigerRMSNormFunction.apply
 liger_rope = LigerRopeFunction.apply
 liger_layer_norm = LigerLayerNormFunction.apply
 liger_kl_div = LigerKLDivLossFunction.apply
+liger_jsd = LigerJSDFunction.apply
+liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
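
The functional aliases wrap `torch.autograd.Function.apply`, so they take positional arguments only. A minimal sketch, with the argument order inferred from the `LigerJSD` module wrapper added later in this diff:

```python
# Minimal sketch (argument order inferred from the module wrappers in this diff):
# liger_jsd(log_q, log_p, shift_labels, beta, ignore_index). Requires a GPU for Triton.
import torch

from liger_kernel.transformers.functional import liger_jsd

BT, V = 4, 10
log_q = torch.randn(BT, V, device="cuda", requires_grad=True).log_softmax(dim=-1)  # student
log_p = torch.randn(BT, V, device="cuda").log_softmax(dim=-1)                       # teacher

loss = liger_jsd(log_q, log_p, None, 0.5, -100)  # shift_labels=None, beta=0.5, ignore_index=-100
loss.backward()
```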
liger_kernel/transformers/fused_linear_jsd.py ADDED
@@ -0,0 +1,98 @@
+from typing import Optional
+
+import torch
+
+from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
+
+
+class LigerFusedLinearJSD(torch.nn.Module):
+    r"""Fusing the last linear layer with generalized JSD
+
+    Handle the forward and backward pass of the final linear layer via JSD by avoiding
+    the materialization of the large logits tensor.
+
+    Args:
+        jsd_beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5`
+        ignore_index (int): The index to ignore in the target. Default: `-100`
+        temperature (float): temperature in softmax function to control the output probability distribution. Default: `1.0`
+
+    Shape:
+        - student_input: :math:`(BT, H)`, where B is batch size, T is sequence length, H is hidden dimension.
+        - student_weight: :math:`(V, H)`, where V is vocab size.
+        - teacher_input: :math:`(BT, H')`, where H' is hidden dimension of the teacher model.
+        - teacher_weight: :math:`(V, H')`, where hidden size H and H' can be different.
+        - shift_labels: :math:`(BT,)`
+        - Output: a scalar.
+
+    Examples:
+    ```python
+    >>> (B, T, H_s, H_t, V) = (2, 2, 3, 5, 10)
+    >>> fused_jsd = LigerFusedLinearJSD(jsd_beta=0.1, temperature=2.0)
+    >>> # generate inputs and weights
+    >>> student_input = torch.rand(B * T, H_s, device="cuda", requires_grad=True)
+    >>> student_lin = torch.nn.Linear(H_s, V, bias=False, device="cuda")
+    >>> # teacher input doesn't require grad, hidden_dim can be different from student's
+    >>> teacher_input = torch.rand(B * T, H_t, device="cuda")
+    >>> teacher_lin = torch.nn.Linear(H_t, V, bias=False, device="cuda")
+    >>> output = fused_jsd(student_input, student_lin.weight, teacher_input, teacher_lin.weight)
+    >>> output.backward()
+    >>>
+    >>> # Example with labels for supervised fine-tuning (SFT) context:
+    >>>
+    >>> # Assume hidden_states, lm_heads and corresponding labels are given
+    >>> student_lm_head = torch.nn.Linear(H_s, V, bias=False)
+    >>> student_hidden_states = torch.randn(B * T, H_s, requires_grad=True).log_softmax(dim=-1)
+    >>> teacher_lm_head = torch.nn.Linear(H_t, V, bias=False)
+    >>> teacher_hidden_states = torch.randn(B * T, H_t).log_softmax(dim=-1)
+    >>> labels = torch.randint(0, V, (B * T,), dtype=torch.long)
+    >>>
+    >>> # Shift so that tokens < n predict n
+    >>> shift_student_hidden_states = student_hidden_states[..., :-1, :].contiguous()
+    >>> shift_teacher_hidden_states = teacher_hidden_states[..., :-1, :].contiguous()
+    >>> shift_labels = labels[..., 1:].contiguous()
+    >>>
+    >>> # Flatten tokens
+    >>> shift_student_hidden_states = shift_student_hidden_states.view(-1, H_s)
+    >>> shift_teacher_hidden_states = shift_teacher_hidden_states.view(-1, H_t)
+    >>> shift_labels = shift_labels.view(-1)
+    >>>
+    >>> # Calculate loss
+    >>> loss_fct = LigerFusedLinearJSD(jsd_beta=0.1)
+    >>> loss = loss_fct(
+    >>>     shift_student_hidden_states,
+    >>>     student_lm_head.weight,
+    >>>     shift_teacher_hidden_states,
+    >>>     teacher_lm_head.weight,
+    >>>     shift_labels
+    >>> )
+    ```
+    """
+
+    def __init__(self, jsd_beta=0.5, ignore_index=-100, temperature=1.0):
+        super().__init__()
+        assert (
+            jsd_beta > 0 and jsd_beta < 1
+        ), f"beta must be greater than 0 and less than 1. Got: {jsd_beta}"
+        assert temperature != 0, "temperature cannot be 0."
+        self.jsd_beta = jsd_beta
+        self.temperature = temperature
+        self.ignore_index = ignore_index
+
+    def forward(
+        self,
+        student_input: torch.Tensor,
+        student_weight: torch.Tensor,
+        teacher_input: torch.Tensor,
+        teacher_weight: torch.Tensor,
+        shift_labels: Optional[torch.LongTensor],
+    ):
+        return LigerFusedLinearJSDFunction.apply(
+            student_input,
+            student_weight,
+            teacher_input,
+            teacher_weight,
+            shift_labels,
+            self.jsd_beta,
+            self.ignore_index,
+            self.temperature,
+        )
liger_kernel/transformers/jsd.py ADDED
@@ -0,0 +1,75 @@
+from typing import Optional
+
+import torch
+
+from liger_kernel.ops.jsd import LigerJSDFunction
+
+
+class LigerJSD(torch.nn.Module):
+    r"""The generalized Jensen-Shannon Divergence.
+    .. math::
+        JSD(\beta)(P || Q)
+            = \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q))
+    .. note::
+        As with all the other losses in PyTorch, this function expects the first argument,
+        :attr:`log_q`, to be the predictions, the output of the student model in log-space,
+        and the second, :attr:`log_p`, to be the observations, the output of the teacher model in log-space.
+        This differs from the standard mathematical notation :math:`JSD(P || Q)` where
+        :math:`P` denotes the teacher model and :math:`Q` denotes the student model.
+
+    Args:
+        beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5`
+        ignore_index (int): The index to ignore in the target. Default: `-100`
+
+    Shape:
+        - Input: :math:`(BT, V)`, where B is batch size, T is sequence length, V is vocab size.
+        - Target: :math:`(BT, V)`, same shape as the input.
+        - shift_labels (Optional): :math:`(BT,)`
+        - Output: a scalar.
+
+    Examples:
+    ```python
+    >>> (B, T, V) = (2, 2, 5)
+    >>> jsd = LigerJSD(beta=0.1)
+    >>> # input should be a distribution in the log space
+    >>> input = torch.randn(B * T, V, requires_grad=True).log_softmax(dim=-1)
+    >>> target = torch.randn(B * T, V).log_softmax(dim=-1)
+    >>> output = jsd(input, target)
+    >>>
+    >>> # Example with labels for supervised fine-tuning (SFT) context
+    >>> # Assume logits and corresponding labels are given
+    >>> student_logits = torch.randn(B * T, V, requires_grad=True).log_softmax(dim=-1)
+    >>> teacher_logits = torch.randn(B * T, V).log_softmax(dim=-1)
+    >>> labels = torch.randint(0, V, (B * T,), dtype=torch.long)
+    >>> # Shift so that tokens < n predict n
+    >>> shift_student_logits = student_logits[..., :-1, :].contiguous()
+    >>> shift_teacher_logits = teacher_logits[..., :-1, :].contiguous()
+    >>> shift_labels = labels[..., 1:].contiguous()
+    >>> # Flatten tokens
+    >>> shift_student_logits = shift_student_logits.view(-1, V)
+    >>> shift_teacher_logits = shift_teacher_logits.view(-1, V)
+    >>> shift_labels = shift_labels.view(-1)
+    >>> # Calculate loss
+    >>> loss_fct = LigerJSD(beta=0.1)
+    >>> loss = loss_fct(shift_student_logits, shift_teacher_logits, shift_labels)
+    ```
+    """
+
+    def __init__(self, beta: float = 0.5, ignore_index: int = -100):
+        super().__init__()
+        assert (
+            beta > 0 and beta < 1
+        ), f"beta must be greater than 0 and less than 1. Got: {beta}"
+        self.beta = beta
+        self.ignore_index = ignore_index
+
+    def forward(
+        self,
+        log_q: torch.Tensor,
+        log_p: torch.Tensor,
+        shift_labels: Optional[torch.LongTensor] = None,
+    ):
+        return LigerJSDFunction.apply(
+            log_q, log_p, shift_labels, self.beta, self.ignore_index
+        )
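
For reference, the docstring formula above written out in plain PyTorch. This is an illustrative re-derivation, not the package's fused kernel; its reduction and `ignore_index` handling are deliberately simplified:

```python
# Illustrative reference of the generalized JSD formula from the docstring above:
# JSD_beta(P || Q) = beta * KL(P || M) + (1 - beta) * KL(Q || M),  with M = beta * P + (1 - beta) * Q
import torch


def generalized_jsd_reference(
    log_q: torch.Tensor,  # student log-probabilities, shape (BT, V)
    log_p: torch.Tensor,  # teacher log-probabilities, shape (BT, V)
    beta: float = 0.5,
) -> torch.Tensor:
    p, q = log_p.exp(), log_q.exp()
    m = beta * p + (1 - beta) * q                 # mixture distribution
    log_m = m.log()
    kl_p_m = (p * (log_p - log_m)).sum(dim=-1)    # KL(P || M) per row
    kl_q_m = (q * (log_q - log_m)).sum(dim=-1)    # KL(Q || M) per row
    return (beta * kl_p_m + (1 - beta) * kl_q_m).mean()
```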
liger_kernel/transformers/model/gemma.py CHANGED
@@ -22,7 +22,7 @@ from liger_kernel.transformers.fused_linear_cross_entropy import (
 @replace_return_docstrings(
     output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
 )
-def lce_forward(
+def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
     attention_mask: Optional[torch.Tensor] = None,
@@ -136,3 +136,126 @@ def lce_forward(
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
     )
+
+
+@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
+@replace_return_docstrings(
+    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+)
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    num_logits_to_keep: int = 0,
+    **loss_kwargs,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        num_logits_to_keep (`int`, *optional*):
+            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, GemmaForCausalLM
+
+    >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
+    >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
+
+    >>> prompt = "What is your favorite condiment?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "What is your favorite condiment?"
+    ```"""
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = (
+        return_dict if return_dict is not None else self.config.use_return_dict
+    )
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+    )
+
+    hidden_states = outputs[0]
+
+    logits = None
+    loss = None
+    # if in training mode, don't materialize logits
+    if self.training and (labels is not None):
+        # We do the same thing as ForCausalLMLoss but using Liger FLCE
+
+        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+
+        # flatten tokens
+        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+        shift_labels = shift_labels.view(-1)
+
+        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
+        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
+
+        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+        if reduction == "sum":
+            loss /= loss_kwargs["num_items_in_batch"]
+
+    else:  # if in inference mode materialize logits
+        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+        if labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                vocab_size=self.config.vocab_size,
+                **loss_kwargs,
+            )
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
liger_kernel/transformers/model/llama.py CHANGED
@@ -1,4 +1,4 @@
-from typing import List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -17,17 +17,20 @@ from liger_kernel.transformers.fused_linear_cross_entropy import (
     LigerFusedLinearCrossEntropyLoss,
 )
 
+if TYPE_CHECKING:
+    from transformers.cache_utils import Cache
+
 
 @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
 @replace_return_docstrings(
     output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
 )
-def lce_forward(
+def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
     attention_mask: Optional[torch.Tensor] = None,
     position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
     labels: Optional[torch.LongTensor] = None,
     use_cache: Optional[bool] = None,
@@ -120,8 +123,9 @@ def lce_forward(
         logits = torch.cat(logits, dim=-1)
     else:
         logits = self.lm_head(hidden_states)
-    logits = logits.float()
     if labels is not None:
+        # Upcast to float if we need to compute the loss to avoid potential precision issues
+        logits = logits.float()
         # Shift so that tokens < n predict n
         shift_logits = logits[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()
@@ -144,3 +148,130 @@ def lce_forward(
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
     )
+
+
+@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+@replace_return_docstrings(
+    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+)
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    num_logits_to_keep: int = 0,
+    **loss_kwargs,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        num_logits_to_keep (`int`, *optional*):
+            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, LlamaForCausalLM
+
+    >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
+    >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

+    >>> prompt = "Hey, are you conscious? Can you talk to me?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+    ```"""
+
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = (
+        return_dict if return_dict is not None else self.config.use_return_dict
+    )
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+    )
+
+    hidden_states = outputs[0]
+
+    if self.config.pretraining_tp > 1:
+        raise Exception("Liger Kernel does not support pretraining_tp!!")
+
+    logits = None
+    loss = None
+    # if in training mode, don't materialize logits
+    if self.training and (labels is not None):
+        # We do the same thing as ForCausalLMLoss but using Liger FLCE
+
+        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+
+        # flatten tokens
+        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+        shift_labels = shift_labels.view(-1)
+
+        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
+        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
+
+        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+        if reduction == "sum":
+            loss /= loss_kwargs["num_items_in_batch"]
+
+    else:  # if in inference mode materialize logits
+        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+        if labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                vocab_size=self.config.vocab_size,
+                **loss_kwargs,
+            )
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
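
The training/inference split in the new `lce_forward` above is what users observe after applying the monkey patch. A minimal sketch, calling `apply_liger_kernel_to_llama` with its defaults (the checkpoint id is illustrative):

```python
# Minimal sketch: patch first, then instantiate the model so the patched forward is used.
import torch

from liger_kernel.transformers import apply_liger_kernel_to_llama

apply_liger_kernel_to_llama()  # defaults; swaps in lce_forward among other kernels

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).cuda()
tokenizer = AutoTokenizer.from_pretrained(model_id)

batch = tokenizer("Hey, are you conscious?", return_tensors="pt").to("cuda")
labels = batch["input_ids"].clone()

model.train()
out = model(**batch, labels=labels)
print(out.loss, out.logits)  # training path: fused linear cross entropy loss, logits is None

model.eval()
with torch.no_grad():
    out = model(**batch, num_logits_to_keep=1)
print(out.logits.shape)      # inference path: logits materialized only for the last position
```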
liger_kernel/transformers/model/mistral.py CHANGED
@@ -136,3 +136,6 @@ def lce_forward(
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
     )
+
+
+# Note: Grad Acc is not fixed in mistral at transformer 4.46.1