liger-kernel-nightly 0.4.0.dev20241107052928__py3-none-any.whl → 0.6.3.dev20251121010306__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of liger-kernel-nightly might be problematic. Click here for more details.

Files changed (114)
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/README.md +25 -0
  3. liger_kernel/chunked_loss/__init__.py +8 -0
  4. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  5. liger_kernel/chunked_loss/cpo_loss.py +157 -0
  6. liger_kernel/chunked_loss/dpo_loss.py +229 -0
  7. liger_kernel/chunked_loss/functional.py +17 -0
  8. liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
  9. liger_kernel/chunked_loss/fused_linear_ppo.py +350 -0
  10. liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
  11. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
  12. liger_kernel/chunked_loss/grpo_loss.py +304 -0
  13. liger_kernel/chunked_loss/jsd_loss.py +200 -0
  14. liger_kernel/chunked_loss/kto_loss.py +210 -0
  15. liger_kernel/chunked_loss/orpo_loss.py +144 -0
  16. liger_kernel/chunked_loss/simpo_loss.py +165 -0
  17. liger_kernel/env_report.py +21 -4
  18. liger_kernel/ops/cross_entropy.py +235 -84
  19. liger_kernel/ops/dyt.py +157 -0
  20. liger_kernel/ops/experimental/embedding.py +1 -3
  21. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  22. liger_kernel/ops/fused_add_rms_norm.py +412 -0
  23. liger_kernel/ops/fused_linear_cross_entropy.py +197 -75
  24. liger_kernel/ops/fused_linear_jsd.py +17 -34
  25. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  26. liger_kernel/ops/geglu.py +7 -18
  27. liger_kernel/ops/group_norm.py +305 -0
  28. liger_kernel/ops/grpo_loss.py +310 -0
  29. liger_kernel/ops/jsd.py +46 -21
  30. liger_kernel/ops/kl_div.py +23 -19
  31. liger_kernel/ops/layer_norm.py +150 -86
  32. liger_kernel/ops/llama4_rope.py +225 -0
  33. liger_kernel/ops/multi_token_attention.py +207 -0
  34. liger_kernel/ops/poly_norm.py +386 -0
  35. liger_kernel/ops/qwen2vl_mrope.py +222 -0
  36. liger_kernel/ops/rms_norm.py +314 -84
  37. liger_kernel/ops/rope.py +32 -34
  38. liger_kernel/ops/softmax.py +201 -0
  39. liger_kernel/ops/sparsemax.py +179 -0
  40. liger_kernel/ops/swiglu.py +5 -9
  41. liger_kernel/ops/tiled_mlp.py +136 -0
  42. liger_kernel/ops/tvd.py +207 -0
  43. liger_kernel/ops/utils.py +8 -4
  44. liger_kernel/transformers/__init__.py +199 -24
  45. liger_kernel/transformers/auto_model.py +6 -13
  46. liger_kernel/transformers/cross_entropy.py +33 -20
  47. liger_kernel/transformers/dyt.py +22 -0
  48. liger_kernel/transformers/experimental/__init__.py +5 -0
  49. liger_kernel/transformers/experimental/embedding.py +1 -3
  50. liger_kernel/transformers/fsdp.py +55 -0
  51. liger_kernel/transformers/functional.py +291 -13
  52. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  53. liger_kernel/transformers/fused_linear_cross_entropy.py +43 -14
  54. liger_kernel/transformers/fused_linear_jsd.py +1 -4
  55. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  56. liger_kernel/transformers/geglu.py +1 -4
  57. liger_kernel/transformers/group_norm.py +50 -0
  58. liger_kernel/transformers/grpo_loss.py +98 -0
  59. liger_kernel/transformers/jsd.py +2 -7
  60. liger_kernel/transformers/kl_div.py +1 -3
  61. liger_kernel/transformers/layer_norm.py +3 -9
  62. liger_kernel/transformers/llama4_rope.py +93 -0
  63. liger_kernel/transformers/model/falcon_h1.py +122 -0
  64. liger_kernel/transformers/model/gemma.py +77 -77
  65. liger_kernel/transformers/model/gemma2.py +283 -0
  66. liger_kernel/transformers/model/gemma3.py +331 -0
  67. liger_kernel/transformers/model/glm4.py +141 -0
  68. liger_kernel/transformers/model/glm4v.py +163 -0
  69. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  70. liger_kernel/transformers/model/internvl.py +157 -0
  71. liger_kernel/transformers/model/llama.py +128 -79
  72. liger_kernel/transformers/model/llama4.py +121 -0
  73. liger_kernel/transformers/model/llava.py +344 -0
  74. liger_kernel/transformers/model/loss_utils.py +95 -0
  75. liger_kernel/transformers/model/mistral.py +68 -64
  76. liger_kernel/transformers/model/mixtral.py +75 -91
  77. liger_kernel/transformers/model/mllama.py +63 -68
  78. liger_kernel/transformers/model/olmo2.py +141 -0
  79. liger_kernel/transformers/model/output_classes.py +147 -0
  80. liger_kernel/transformers/model/paligemma.py +432 -0
  81. liger_kernel/transformers/model/phi3.py +59 -213
  82. liger_kernel/transformers/model/qwen2.py +75 -72
  83. liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
  84. liger_kernel/transformers/model/qwen2_vl.py +78 -98
  85. liger_kernel/transformers/model/qwen3.py +136 -0
  86. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  87. liger_kernel/transformers/model/qwen3_next.py +146 -0
  88. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  89. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  90. liger_kernel/transformers/model/smollm3.py +199 -0
  91. liger_kernel/transformers/model/smolvlm.py +158 -0
  92. liger_kernel/transformers/monkey_patch.py +2106 -289
  93. liger_kernel/transformers/multi_token_attention.py +64 -0
  94. liger_kernel/transformers/poly_norm.py +42 -0
  95. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  96. liger_kernel/transformers/rms_norm.py +57 -6
  97. liger_kernel/transformers/rope.py +45 -2
  98. liger_kernel/transformers/softmax.py +12 -0
  99. liger_kernel/transformers/sparsemax.py +16 -0
  100. liger_kernel/transformers/swiglu.py +23 -8
  101. liger_kernel/transformers/tiled_mlp.py +133 -0
  102. liger_kernel/transformers/trainer/__init__.py +4 -0
  103. liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
  104. liger_kernel/transformers/tvd.py +13 -0
  105. liger_kernel/triton/__init__.py +1 -3
  106. liger_kernel/triton/monkey_patch.py +1 -3
  107. liger_kernel/utils.py +71 -0
  108. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +150 -137
  109. liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
  110. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +1 -1
  111. liger_kernel_nightly-0.4.0.dev20241107052928.dist-info/RECORD +0 -48
  112. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
  113. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
  114. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,172 @@
1
+ from typing import Optional
2
+ from typing import Tuple
3
+ from typing import Union
4
+
5
+ import torch
6
+
7
+ from transformers.utils.deprecation import deprecate_kwarg
8
+
9
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
10
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
11
+ from liger_kernel.transformers.model.output_classes import LigerGlm4vMoeCausalLMOutputWithPast
12
+
13
+
14
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[list[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    pixel_values: Optional[torch.Tensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    rope_deltas: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    skip_logits: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    **kwargs,
) -> Union[Tuple, LigerGlm4vMoeCausalLMOutputWithPast]:
    r"""
    GLM-4V MoE causal-LM forward that can compute the loss with Liger's fused linear cross entropy.

    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
            A pre-shifted `shift_labels` tensor may instead be passed through `**kwargs`.
        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
            The temporal, height and width of feature shape of each image in LLM.
        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
            The temporal, height and width of feature shape of each video in LLM.
        rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
            The rope index difference between sequence length and multimodal rope.
            Accepted for signature compatibility; this function does not forward it to `self.model`.
        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
            This is useful when using packed tensor format (single dimension for batch and sequence length).
        skip_logits (`bool`, *optional*):
            If `True`, the loss is computed from the hidden states with `LigerForCausalLMLoss` and `logits` is
            not materialized (it is returned as `None`); requires `labels` or `shift_labels`. If `None`, it
            defaults to `True` only in training mode when `labels` or `shift_labels` is provided.
        return_dict (`bool`, *optional*):
            Whether to return a `LigerGlm4vMoeCausalLMOutputWithPast` instead of a plain tuple. Defaults to
            `config.use_return_dict`.

    Returns:
        `LigerGlm4vMoeCausalLMOutputWithPast` carrying `loss`, `logits`, `past_key_values`, `hidden_states`,
        `attentions`, `rope_deltas`, `token_accuracy` and, when the backbone output has it, `aux_loss`.
        When `return_dict` is falsy, a tuple `(loss?, logits, *outputs[1:], token_accuracy?)` where the
        optional entries are present only when not `None`.

    Raises:
        ValueError: if `skip_logits` is `True` but neither `labels` nor `shift_labels` is given.

    Example:

    ```python
    >>> from transformers import AutoProcessor, Glm4vMoeForConditionalGeneration
    >>> import torch

    >>> MODEL_PATH = "zai-org/GLM-4.5V"
    >>> messages = [
    ...     {
    ...         "role": "user",
    ...         "content": [
    ...             {
    ...                 "type": "image",
    ...                 "url": "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png",
    ...             },
    ...             {"type": "text", "text": "describe this image"},
    ...         ],
    ...     }
    ... ]
    >>> processor = AutoProcessor.from_pretrained(MODEL_PATH)
    >>> model = Glm4vMoeForConditionalGeneration.from_pretrained(
    ...     pretrained_model_name_or_path=MODEL_PATH,
    ...     dtype="auto",
    ...     device_map="auto",
    ... )
    >>> inputs = processor.apply_chat_template(
    ...     messages,
    ...     tokenize=True,
    ...     add_generation_prompt=True,
    ...     return_dict=True,
    ...     return_tensors="pt",
    ... ).to(model.device)
    >>> inputs.pop("token_type_ids", None)
    >>> generated_ids = model.generate(**inputs, max_new_tokens=8192)
    >>> output_text = processor.decode(generated_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=False)
    ```
    """
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Composite VLM configs keep the language-model sizes on `text_config`; fall back to the
    # top-level config when no sub-config is present.
    text_config = getattr(self.config, "text_config", None)
    if text_config is None:
        text_config = self.config

    outputs = self.model(
        input_ids=input_ids,
        pixel_values=pixel_values,
        pixel_values_videos=pixel_values_videos,
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs[0]
    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    kept_hidden_states = hidden_states[:, slice_indices, :]

    # Pop so `shift_labels` is not passed twice (explicitly and via **kwargs) to the loss calls below.
    shift_labels = kwargs.pop("shift_labels", None)
    logits = None
    loss = None
    token_accuracy = None

    if skip_logits and labels is None and shift_labels is None:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")

    if skip_logits is None:
        # By default, if in training mode, don't materialize logits
        skip_logits = self.training and (labels is not None or shift_labels is not None)

    if skip_logits:
        # Fused path: the loss is computed straight from hidden states and the lm_head weight.
        result = LigerForCausalLMLoss(
            hidden_states=kept_hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
            shift_labels=shift_labels,
            hidden_size=text_config.hidden_size,
            **kwargs,
        )
        loss, _, token_accuracy = unpack_cross_entropy_result(result)
    else:
        logits = self.lm_head(kept_hidden_states)
        if labels is not None or shift_labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                shift_labels=shift_labels,
                vocab_size=text_config.vocab_size,
                **kwargs,
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        output = ((loss,) + output) if loss is not None else output
        output = output + (token_accuracy,) if token_accuracy is not None else output
        return output

    # `aux_loss` is only present on some backbone outputs (depends on the transformers version).
    output_kwargs = dict(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        rope_deltas=outputs.rope_deltas,
        token_accuracy=token_accuracy,
    )
    if hasattr(outputs, "aux_loss"):
        output_kwargs["aux_loss"] = outputs.aux_loss

    return LigerGlm4vMoeCausalLMOutputWithPast(**output_kwargs)
@@ -0,0 +1,157 @@
1
+ from typing import List
2
+ from typing import Optional
3
+ from typing import Tuple
4
+ from typing import Union
5
+
6
+ import torch
7
+
8
+ from transformers.utils import can_return_tuple
9
+
10
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
11
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
12
+ from liger_kernel.transformers.model.output_classes import LigerInternVLCausalLMOutputWithPast
13
+
14
+
15
# Adapted (with Liger changes) from https://github.com/huggingface/transformers/blob/d888bd435d0c0eaabaabad5b33d52af518c7187c/src/transformers/models/internvl/modeling_internvl.py#L862
@can_return_tuple
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    pixel_values: Optional[torch.FloatTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    vision_feature_layer: Optional[Union[int, List[int]]] = None,
    vision_feature_select_strategy: Optional[str] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    image_sizes: Optional[torch.Tensor] = None,
    skip_logits: Optional[bool] = None,  # Added argument for liger-kernel
    **lm_kwargs,  # renamed from kwargs
) -> Union[Tuple, LigerInternVLCausalLMOutputWithPast]:
    r"""
    InternVL causal-LM forward that can compute the loss with Liger's fused linear cross entropy.

    Args:
        labels (`torch.LongTensor`, *optional*):
            Token labels for the causal-LM loss (`-100` entries are ignored by the loss). A pre-shifted
            `shift_labels` tensor may instead be passed through `**lm_kwargs`.
        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, keep the last `logits_to_keep` positions (`0` keeps all). If a tensor, it is used
            directly as the index along the sequence dimension.
        skip_logits (`bool`, *optional*):
            If `True`, the loss is computed from the hidden states with `LigerForCausalLMLoss` and `logits` is
            not materialized (returned as `None`); requires `labels` or `shift_labels`. If `None`, it defaults
            to `True` only in training mode when `labels` or `shift_labels` is provided.

    Returns:
        `LigerInternVLCausalLMOutputWithPast` with `loss`, `logits`, `past_key_values`, `hidden_states`,
        `attentions`, `image_hidden_states` and `token_accuracy`; or, when `return_dict` is falsy, a tuple
        `(loss?, logits, *outputs[1:], token_accuracy?)` with the optional entries present only when not `None`.

    Raises:
        ValueError: if `skip_logits` is `True` but neither `labels` nor `shift_labels` is given.

    Example:

    ```python
    >>> import torch
    >>> from transformers import AutoProcessor, AutoModelForImageTextToText

    >>> torch_device = "cuda"
    >>> processor = AutoProcessor.from_pretrained("OpenGVLab/InternVL3-1B-hf")
    >>> model = AutoModelForImageTextToText.from_pretrained(
    ...     "OpenGVLab/InternVL3-1B-hf", dtype=torch.bfloat16, device_map=torch_device
    ... )

    >>> messages = [
    ...     {
    ...         "role": "user",
    ...         "content": [
    ...             {
    ...                 "type": "image",
    ...                 "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
    ...             },
    ...             {
    ...                 "type": "image",
    ...                 "url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg",
    ...             },
    ...             {"type": "text", "text": "These images depict two different landmarks. Can you identify them?"},
    ...         ],
    ...     },
    ... ]

    >>> inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(torch_device)
    >>> generate_ids = model.generate(**inputs, max_new_tokens=200)
    >>> print(processor.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True))
    The images depict the Statue of Liberty and the Golden Gate Bridge.
    ```"""

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    vision_feature_layer = (
        vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
    )
    vision_feature_select_strategy = (
        vision_feature_select_strategy
        if vision_feature_select_strategy is not None
        else self.config.vision_feature_select_strategy
    )

    outputs = self.model(
        input_ids=input_ids,
        pixel_values=pixel_values,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        vision_feature_layer=vision_feature_layer,
        vision_feature_select_strategy=vision_feature_select_strategy,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        image_sizes=image_sizes,
        **lm_kwargs,
    )

    hidden_states = outputs[0]
    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    kept_hidden_states = hidden_states[:, slice_indices, :]

    # Pop so `shift_labels` is not passed twice (explicitly and via **lm_kwargs) to the loss calls below.
    shift_labels = lm_kwargs.pop("shift_labels", None)
    logits = None
    loss = None
    token_accuracy = None

    if skip_logits and labels is None and shift_labels is None:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")

    if skip_logits is None:
        # By default, if in training mode, don't materialize logits
        skip_logits = self.training and (labels is not None or shift_labels is not None)

    if skip_logits:
        result = LigerForCausalLMLoss(
            hidden_states=kept_hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
            shift_labels=shift_labels,
            hidden_size=self.config.text_config.hidden_size,
            **lm_kwargs,
        )
        loss, _, token_accuracy = unpack_cross_entropy_result(result)
    else:
        logits = self.lm_head(kept_hidden_states)
        # Honour `shift_labels` here too; previously it was popped and then silently dropped.
        if labels is not None or shift_labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                shift_labels=shift_labels,
                vocab_size=self.config.text_config.vocab_size,
                **lm_kwargs,
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        output = (loss,) + output if loss is not None else output
        output = output + (token_accuracy,) if token_accuracy is not None else output
        return output

    # Return custom output class with token_accuracy field
    return LigerInternVLCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        image_hidden_states=outputs.image_hidden_states,
        token_accuracy=token_accuracy,
    )
@@ -1,30 +1,31 @@
1
- from typing import TYPE_CHECKING, List, Optional, Tuple, Union
1
+ from typing import TYPE_CHECKING
2
+ from typing import List
3
+ from typing import Optional
4
+ from typing import Tuple
5
+ from typing import Union
2
6
 
3
7
  import torch
4
8
  import torch.nn.functional as F
9
+
10
+ from torch.distributed.fsdp import FullyShardedDataParallel
5
11
  from torch.nn import CrossEntropyLoss
6
12
  from transformers.modeling_outputs import CausalLMOutputWithPast
7
- from transformers.models.llama.modeling_llama import (
8
- _CONFIG_FOR_DOC,
9
- LLAMA_INPUTS_DOCSTRING,
10
- )
11
- from transformers.utils import (
12
- add_start_docstrings_to_model_forward,
13
- replace_return_docstrings,
14
- )
15
-
16
- from liger_kernel.transformers.fused_linear_cross_entropy import (
17
- LigerFusedLinearCrossEntropyLoss,
18
- )
13
+ from transformers.utils.deprecation import deprecate_kwarg
14
+
15
+ from liger_kernel.transformers.fsdp import _FSDPForwardRedirection
16
+ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
17
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
18
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
19
+ from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
20
+ from liger_kernel.utils import PEFT_AVAILABLE
19
21
 
20
22
  if TYPE_CHECKING:
21
23
  from transformers.cache_utils import Cache
22
24
 
25
+ if PEFT_AVAILABLE:
26
+ from peft.utils.other import ModulesToSaveWrapper
27
+
23
28
 
24
- @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
25
- @replace_return_docstrings(
26
- output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
27
- )
28
29
  def lce_forward_deprecated(
29
30
  self,
30
31
  input_ids: torch.LongTensor = None,
@@ -38,6 +39,7 @@ def lce_forward_deprecated(
38
39
  output_hidden_states: Optional[bool] = None,
39
40
  return_dict: Optional[bool] = None,
40
41
  cache_position: Optional[torch.LongTensor] = None,
42
+ skip_logits: Optional[bool] = None,
41
43
  ) -> Union[Tuple, CausalLMOutputWithPast]:
42
44
  r"""
43
45
  Copy paste llama forward but replace torch cross entropy with liger fused linear cross entropy
@@ -67,19 +69,11 @@ def lce_forward_deprecated(
67
69
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
68
70
  "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
69
71
  ```"""
70
- output_attentions = (
71
- output_attentions
72
- if output_attentions is not None
73
- else self.config.output_attentions
74
- )
72
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
75
73
  output_hidden_states = (
76
- output_hidden_states
77
- if output_hidden_states is not None
78
- else self.config.output_hidden_states
79
- )
80
- return_dict = (
81
- return_dict if return_dict is not None else self.config.use_return_dict
74
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
82
75
  )
76
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
83
77
 
84
78
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
85
79
  outputs = self.model(
@@ -100,7 +94,15 @@ def lce_forward_deprecated(
100
94
  loss = None
101
95
  logits = None
102
96
 
103
- if self.training and (labels is not None):
97
+ # if in training mode, don't materialize logits
98
+ if skip_logits and labels is None:
99
+ raise ValueError("skip_logits is True, but labels is None")
100
+
101
+ if skip_logits is None:
102
+ # By default, if in training mode, don't materialize logits
103
+ skip_logits = self.training and labels is not None
104
+
105
+ if skip_logits:
104
106
  shift_hidden_states = hidden_states[..., :-1, :].contiguous()
105
107
  shift_labels = labels[..., 1:].contiguous()
106
108
 
@@ -113,13 +115,8 @@ def lce_forward_deprecated(
113
115
 
114
116
  else:
115
117
  if self.config.pretraining_tp > 1:
116
- lm_head_slices = self.lm_head.weight.split(
117
- self.vocab_size // self.config.pretraining_tp, dim=0
118
- )
119
- logits = [
120
- F.linear(hidden_states, lm_head_slices[i])
121
- for i in range(self.config.pretraining_tp)
122
- ]
118
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
119
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
123
120
  logits = torch.cat(logits, dim=-1)
124
121
  else:
125
122
  logits = self.lm_head(hidden_states)
@@ -150,10 +147,7 @@ def lce_forward_deprecated(
150
147
  )
151
148
 
152
149
 
153
- @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
154
- @replace_return_docstrings(
155
- output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
156
- )
150
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
157
151
  def lce_forward(
158
152
  self,
159
153
  input_ids: torch.LongTensor = None,
@@ -167,9 +161,10 @@ def lce_forward(
167
161
  output_hidden_states: Optional[bool] = None,
168
162
  return_dict: Optional[bool] = None,
169
163
  cache_position: Optional[torch.LongTensor] = None,
170
- num_logits_to_keep: int = 0,
171
- **loss_kwargs,
172
- ) -> Union[Tuple, CausalLMOutputWithPast]:
164
+ logits_to_keep: Union[int, torch.Tensor] = 0,
165
+ skip_logits: Optional[bool] = None,
166
+ **kwargs,
167
+ ) -> Union[Tuple, LigerCausalLMOutputWithPast]:
173
168
  r"""
174
169
  Args:
175
170
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -177,10 +172,12 @@ def lce_forward(
177
172
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
178
173
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
179
174
 
180
- num_logits_to_keep (`int`, *optional*):
181
- Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
175
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
176
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
182
177
  `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
183
178
  token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
179
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
180
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
184
181
 
185
182
  Returns:
186
183
 
@@ -201,19 +198,11 @@ def lce_forward(
201
198
  "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
202
199
  ```"""
203
200
 
204
- output_attentions = (
205
- output_attentions
206
- if output_attentions is not None
207
- else self.config.output_attentions
208
- )
201
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
209
202
  output_hidden_states = (
210
- output_hidden_states
211
- if output_hidden_states is not None
212
- else self.config.output_hidden_states
213
- )
214
- return_dict = (
215
- return_dict if return_dict is not None else self.config.use_return_dict
203
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
216
204
  )
205
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
217
206
 
218
207
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
219
208
  outputs = self.model(
@@ -227,51 +216,111 @@ def lce_forward(
227
216
  output_hidden_states=output_hidden_states,
228
217
  return_dict=return_dict,
229
218
  cache_position=cache_position,
219
+ **kwargs,
230
220
  )
231
221
 
232
222
  hidden_states = outputs[0]
223
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
224
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
225
+ kept_hidden_states = hidden_states[:, slice_indices, :]
233
226
 
234
227
  if self.config.pretraining_tp > 1:
235
228
  raise Exception("Liger Kernel does not support pretraining_tp!!")
236
229
 
230
+ shift_labels = kwargs.pop("shift_labels", None)
237
231
  logits = None
238
232
  loss = None
239
- # if in training mode, don't materialize logits
240
- if self.training and (labels is not None):
241
- # We do the same thing as ForCausalLMLoss but using Liger FLCE
242
-
243
- shift_hidden_states = hidden_states[..., :-1, :].contiguous()
244
- shift_labels = labels[..., 1:].contiguous()
245
-
246
- # flatten tokens
247
- shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
248
- shift_labels = shift_labels.view(-1)
233
+ token_accuracy = None
249
234
 
250
- reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
251
- lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
252
-
253
- loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
254
- if reduction == "sum":
255
- loss /= loss_kwargs["num_items_in_batch"]
256
-
257
- else: # if in inference mode materialize logits
258
- logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
259
- if labels is not None:
235
+ # if in training mode, don't materialize logits
236
+ if skip_logits and labels is None and shift_labels is None:
237
+ raise ValueError("skip_logits is True, but labels and shift_labels are None")
238
+
239
+ if skip_logits is None:
240
+ # By default, if in training mode, don't materialize logits
241
+ skip_logits = self.training and (labels is not None or shift_labels is not None)
242
+
243
+ # Compute loss
244
+ if skip_logits:
245
+ result = lce_maybe_trainable_lm_head(
246
+ self,
247
+ hidden_states=kept_hidden_states,
248
+ hidden_size=self.config.hidden_size,
249
+ labels=labels,
250
+ shift_labels=shift_labels,
251
+ **kwargs,
252
+ )
253
+ loss, _, token_accuracy = unpack_cross_entropy_result(result)
254
+ else:
255
+ logits = self.lm_head(kept_hidden_states)
256
+ if labels is not None or shift_labels is not None:
260
257
  loss = self.loss_function(
261
258
  logits=logits,
262
259
  labels=labels,
260
+ shift_labels=shift_labels,
263
261
  vocab_size=self.config.vocab_size,
264
- **loss_kwargs,
262
+ **kwargs,
265
263
  )
266
264
 
267
265
  if not return_dict:
268
266
  output = (logits,) + outputs[1:]
269
- return (loss,) + output if loss is not None else output
267
+ output = ((loss,) + output) if loss is not None else output
268
+ output = output + (token_accuracy,) if token_accuracy is not None else output
269
+ return output
270
270
 
271
- return CausalLMOutputWithPast(
271
+ # Return custom output class with token_accuracy field
272
+ return LigerCausalLMOutputWithPast(
272
273
  loss=loss,
273
274
  logits=logits,
274
275
  past_key_values=outputs.past_key_values,
275
276
  hidden_states=outputs.hidden_states,
276
277
  attentions=outputs.attentions,
278
+ token_accuracy=token_accuracy,
279
+ )
280
+
281
+
282
def lce_maybe_trainable_lm_head(self, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs):
    """Compute the Liger causal-LM loss against ``self.lm_head``, unwrapping PEFT/FSDP wrappers.

    Args:
        self: the causal-LM module that owns ``lm_head``.
        hidden_states: hidden states fed to the LM head.
        hidden_size: hidden dimension, passed through to ``LigerForCausalLMLoss``.
        labels: token labels (may be ``None`` when ``shift_labels`` is given).
        shift_labels: pre-shifted labels, or ``None``.
        **loss_kwargs: extra keyword arguments forwarded to ``LigerForCausalLMLoss``.

    Returns:
        The result of ``LigerForCausalLMLoss`` computed with the (unwrapped) head's weight.
    """
    lm_head = self.lm_head

    # Unwrap the module if lm_head has been added as trainable module in PEFT LoRA configuration,
    # i.e. listed in the modules_to_save field of LoraConfig, so the lm_head weights are read
    # from the unwrapped module.
    # See https://huggingface.co/docs/peft/package_reference/lora for reference.
    # NOTE(review): only the adapter named "default" is unwrapped here.
    if PEFT_AVAILABLE and isinstance(lm_head, ModulesToSaveWrapper):
        lm_head = lm_head.modules_to_save.default

    # If FSDP is used and lm_head is trainable, e.g., during full fine-tuning or with LoRA,
    # reading the lm_head module weights and calling the kernel must be done within FSDP forward pass
    # so the module entire parameters are summoned and kept in memory during the kernel execution.
    if isinstance(lm_head, FullyShardedDataParallel):
        return _FSDPForwardRedirection()(
            lm_head,
            _liger_for_causal_lm_loss,
            lm_head.module,
            hidden_states,
            hidden_size,
            labels,
            shift_labels,
            **loss_kwargs,
        )

    # FSDP is not used, so read the weights directly. Use the unwrapped `lm_head` (not
    # `self.lm_head`) so a PEFT-wrapped head contributes its trainable copy's weight.
    return _liger_for_causal_lm_loss(
        lm_head=lm_head,
        hidden_states=hidden_states,
        hidden_size=hidden_size,
        labels=labels,
        shift_labels=shift_labels,
        **loss_kwargs,
    )
316
+
317
+
318
def _liger_for_causal_lm_loss(lm_head, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs):
    """Run ``LigerForCausalLMLoss`` using the ``weight`` of the given LM-head module.

    Takes the head module itself (rather than its weight) as the first argument, so it can be
    handed to ``_FSDPForwardRedirection`` together with the wrapped module.

    Returns:
        Whatever ``LigerForCausalLMLoss`` returns.
    """
    head_weight = lm_head.weight
    return LigerForCausalLMLoss(
        hidden_states=hidden_states,
        lm_head_weight=head_weight,
        hidden_size=hidden_size,
        labels=labels,
        shift_labels=shift_labels,
        **loss_kwargs,
    )