liger-kernel-nightly 0.4.0.dev20241107052928__py3-none-any.whl → 0.6.3.dev20251121010306__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of liger-kernel-nightly has been flagged as a potentially problematic release.

Files changed (114)
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/README.md +25 -0
  3. liger_kernel/chunked_loss/__init__.py +8 -0
  4. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  5. liger_kernel/chunked_loss/cpo_loss.py +157 -0
  6. liger_kernel/chunked_loss/dpo_loss.py +229 -0
  7. liger_kernel/chunked_loss/functional.py +17 -0
  8. liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
  9. liger_kernel/chunked_loss/fused_linear_ppo.py +350 -0
  10. liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
  11. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
  12. liger_kernel/chunked_loss/grpo_loss.py +304 -0
  13. liger_kernel/chunked_loss/jsd_loss.py +200 -0
  14. liger_kernel/chunked_loss/kto_loss.py +210 -0
  15. liger_kernel/chunked_loss/orpo_loss.py +144 -0
  16. liger_kernel/chunked_loss/simpo_loss.py +165 -0
  17. liger_kernel/env_report.py +21 -4
  18. liger_kernel/ops/cross_entropy.py +235 -84
  19. liger_kernel/ops/dyt.py +157 -0
  20. liger_kernel/ops/experimental/embedding.py +1 -3
  21. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  22. liger_kernel/ops/fused_add_rms_norm.py +412 -0
  23. liger_kernel/ops/fused_linear_cross_entropy.py +197 -75
  24. liger_kernel/ops/fused_linear_jsd.py +17 -34
  25. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  26. liger_kernel/ops/geglu.py +7 -18
  27. liger_kernel/ops/group_norm.py +305 -0
  28. liger_kernel/ops/grpo_loss.py +310 -0
  29. liger_kernel/ops/jsd.py +46 -21
  30. liger_kernel/ops/kl_div.py +23 -19
  31. liger_kernel/ops/layer_norm.py +150 -86
  32. liger_kernel/ops/llama4_rope.py +225 -0
  33. liger_kernel/ops/multi_token_attention.py +207 -0
  34. liger_kernel/ops/poly_norm.py +386 -0
  35. liger_kernel/ops/qwen2vl_mrope.py +222 -0
  36. liger_kernel/ops/rms_norm.py +314 -84
  37. liger_kernel/ops/rope.py +32 -34
  38. liger_kernel/ops/softmax.py +201 -0
  39. liger_kernel/ops/sparsemax.py +179 -0
  40. liger_kernel/ops/swiglu.py +5 -9
  41. liger_kernel/ops/tiled_mlp.py +136 -0
  42. liger_kernel/ops/tvd.py +207 -0
  43. liger_kernel/ops/utils.py +8 -4
  44. liger_kernel/transformers/__init__.py +199 -24
  45. liger_kernel/transformers/auto_model.py +6 -13
  46. liger_kernel/transformers/cross_entropy.py +33 -20
  47. liger_kernel/transformers/dyt.py +22 -0
  48. liger_kernel/transformers/experimental/__init__.py +5 -0
  49. liger_kernel/transformers/experimental/embedding.py +1 -3
  50. liger_kernel/transformers/fsdp.py +55 -0
  51. liger_kernel/transformers/functional.py +291 -13
  52. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  53. liger_kernel/transformers/fused_linear_cross_entropy.py +43 -14
  54. liger_kernel/transformers/fused_linear_jsd.py +1 -4
  55. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  56. liger_kernel/transformers/geglu.py +1 -4
  57. liger_kernel/transformers/group_norm.py +50 -0
  58. liger_kernel/transformers/grpo_loss.py +98 -0
  59. liger_kernel/transformers/jsd.py +2 -7
  60. liger_kernel/transformers/kl_div.py +1 -3
  61. liger_kernel/transformers/layer_norm.py +3 -9
  62. liger_kernel/transformers/llama4_rope.py +93 -0
  63. liger_kernel/transformers/model/falcon_h1.py +122 -0
  64. liger_kernel/transformers/model/gemma.py +77 -77
  65. liger_kernel/transformers/model/gemma2.py +283 -0
  66. liger_kernel/transformers/model/gemma3.py +331 -0
  67. liger_kernel/transformers/model/glm4.py +141 -0
  68. liger_kernel/transformers/model/glm4v.py +163 -0
  69. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  70. liger_kernel/transformers/model/internvl.py +157 -0
  71. liger_kernel/transformers/model/llama.py +128 -79
  72. liger_kernel/transformers/model/llama4.py +121 -0
  73. liger_kernel/transformers/model/llava.py +344 -0
  74. liger_kernel/transformers/model/loss_utils.py +95 -0
  75. liger_kernel/transformers/model/mistral.py +68 -64
  76. liger_kernel/transformers/model/mixtral.py +75 -91
  77. liger_kernel/transformers/model/mllama.py +63 -68
  78. liger_kernel/transformers/model/olmo2.py +141 -0
  79. liger_kernel/transformers/model/output_classes.py +147 -0
  80. liger_kernel/transformers/model/paligemma.py +432 -0
  81. liger_kernel/transformers/model/phi3.py +59 -213
  82. liger_kernel/transformers/model/qwen2.py +75 -72
  83. liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
  84. liger_kernel/transformers/model/qwen2_vl.py +78 -98
  85. liger_kernel/transformers/model/qwen3.py +136 -0
  86. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  87. liger_kernel/transformers/model/qwen3_next.py +146 -0
  88. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  89. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  90. liger_kernel/transformers/model/smollm3.py +199 -0
  91. liger_kernel/transformers/model/smolvlm.py +158 -0
  92. liger_kernel/transformers/monkey_patch.py +2106 -289
  93. liger_kernel/transformers/multi_token_attention.py +64 -0
  94. liger_kernel/transformers/poly_norm.py +42 -0
  95. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  96. liger_kernel/transformers/rms_norm.py +57 -6
  97. liger_kernel/transformers/rope.py +45 -2
  98. liger_kernel/transformers/softmax.py +12 -0
  99. liger_kernel/transformers/sparsemax.py +16 -0
  100. liger_kernel/transformers/swiglu.py +23 -8
  101. liger_kernel/transformers/tiled_mlp.py +133 -0
  102. liger_kernel/transformers/trainer/__init__.py +4 -0
  103. liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
  104. liger_kernel/transformers/tvd.py +13 -0
  105. liger_kernel/triton/__init__.py +1 -3
  106. liger_kernel/triton/monkey_patch.py +1 -3
  107. liger_kernel/utils.py +71 -0
  108. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +150 -137
  109. liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
  110. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +1 -1
  111. liger_kernel_nightly-0.4.0.dev20241107052928.dist-info/RECORD +0 -48
  112. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
  113. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
  114. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/model/olmo2.py (new file)
@@ -0,0 +1,141 @@
+ from typing import List
+ from typing import Optional
+ from typing import Tuple
+ from typing import Union
+
+ import torch
+
+ from transformers.utils.deprecation import deprecate_kwarg
+
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+ from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
+
+
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
+ def lce_forward(
+     self,
+     input_ids: torch.LongTensor = None,
+     attention_mask: Optional[torch.Tensor] = None,
+     position_ids: Optional[torch.LongTensor] = None,
+     past_key_values: Optional[List[torch.FloatTensor]] = None,
+     inputs_embeds: Optional[torch.FloatTensor] = None,
+     labels: Optional[torch.LongTensor] = None,
+     use_cache: Optional[bool] = None,
+     output_attentions: Optional[bool] = None,
+     output_hidden_states: Optional[bool] = None,
+     return_dict: Optional[bool] = None,
+     cache_position: Optional[torch.LongTensor] = None,
+     logits_to_keep: Union[int, torch.Tensor] = 0,
+     skip_logits: Optional[bool] = None,
+     **kwargs,
+ ) -> Union[Tuple, LigerCausalLMOutputWithPast]:
+     r"""
+     Args:
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+         logits_to_keep (`int` or `torch.Tensor`, *optional*):
+             If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+             `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+             token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+             If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+             This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+     Returns:
+
+     Example:
+
+     ```python
+     >>> from transformers import AutoTokenizer, Olmo2ForCausalLM
+
+     >>> model = Olmo2ForCausalLM.from_pretrained("allenai/Olmo2-1B-hf")
+     >>> tokenizer = AutoTokenizer.from_pretrained("allenai/Olmo2-1B-hf")
+
+     >>> prompt = "Hey, are you conscious? Can you talk to me?"
+     >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+     >>> # Generate
+     >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+     'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m'
+     ```
+     """
+     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+     output_hidden_states = (
+         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+     )
+     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+     outputs = self.model(
+         input_ids=input_ids,
+         attention_mask=attention_mask,
+         position_ids=position_ids,
+         past_key_values=past_key_values,
+         inputs_embeds=inputs_embeds,
+         use_cache=use_cache,
+         output_attentions=output_attentions,
+         output_hidden_states=output_hidden_states,
+         return_dict=return_dict,
+         cache_position=cache_position,
+         **kwargs,
+     )
+
+     hidden_states = outputs[0]
+     # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+     kept_hidden_states = hidden_states[:, slice_indices, :]
+
+     shift_labels = kwargs.pop("shift_labels", None)
+     logits = None
+     loss = None
+     token_accuracy = None
+
+     if skip_logits and labels is None and shift_labels is None:
+         raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+     if skip_logits is None:
+         # By default, if in training mode, don't materialize logits
+         skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+     # Compute loss
+     if skip_logits:
+         result = LigerForCausalLMLoss(
+             hidden_states=kept_hidden_states,
+             lm_head_weight=self.lm_head.weight,
+             labels=labels,
+             shift_labels=shift_labels,
+             hidden_size=self.config.hidden_size,
+             **kwargs,
+         )
+         loss, _, token_accuracy = unpack_cross_entropy_result(result)
+
+     else:
+         logits = self.lm_head(kept_hidden_states)
+         if labels is not None or shift_labels is not None:
+             loss = self.loss_function(
+                 logits=logits,
+                 labels=labels,
+                 shift_labels=shift_labels,
+                 vocab_size=self.config.vocab_size,
+                 **kwargs,
+             )
+
+     if not return_dict:
+         output = (logits,) + outputs[1:]
+         output = ((loss,) + output) if loss is not None else output
+         output = output + (token_accuracy,) if token_accuracy is not None else output
+         return output
+
+     # Return custom output class with token_accuracy field
+     return LigerCausalLMOutputWithPast(
+         loss=loss,
+         logits=logits,
+         past_key_values=outputs.past_key_values,
+         hidden_states=outputs.hidden_states,
+         attentions=outputs.attentions,
+         token_accuracy=token_accuracy,
+     )
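
The forward above only materializes logits when they are needed: in training mode with `labels` (or `shift_labels`), `skip_logits` defaults to `True` and the fused `LigerForCausalLMLoss` path computes the loss directly from the hidden states and `lm_head.weight`, leaving `logits` as `None`. Below is a minimal usage sketch, not part of the diff, assuming the Liger monkey patch has already swapped this `lce_forward` into `Olmo2ForCausalLM` and that a GPU is available for the Triton kernels:

```python
import torch
from transformers import AutoModelForCausalLM

# Assumption: Olmo2ForCausalLM.forward has already been replaced by the
# lce_forward shown above (e.g. via liger_kernel's monkey patching).
model = AutoModelForCausalLM.from_pretrained("allenai/Olmo2-1B-hf")

input_ids = torch.randint(0, model.config.vocab_size, (2, 16))
labels = input_ids.clone()

# Training with labels: skip_logits defaults to True, so the fused
# linear + cross-entropy kernel returns a loss without materializing logits.
model.train()
out = model(input_ids=input_ids, labels=labels)
print(out.loss, out.logits)  # tensor(...), None

# Inference: logits are materialized; logits_to_keep=1 slices the hidden
# states so only the last position's logits are computed.
model.eval()
with torch.no_grad():
    out = model(input_ids=input_ids, logits_to_keep=1)
print(out.logits.shape)  # torch.Size([2, 1, vocab_size])
```
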
liger_kernel/transformers/model/output_classes.py (new file)
@@ -0,0 +1,147 @@
+ """
+ Custom output classes for Liger-Kernel that extend transformers' ModelOutput classes
+ with optional token accuracy field.
+ """
+
+ from dataclasses import dataclass
+ from typing import Optional
+
+ import torch
+
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+ from transformers.modeling_outputs import MoeCausalLMOutputWithPast
+
+ # The following model-specific outputs are optional and depend on the installed
+ # transformers version. Guard their imports so our module remains importable
+ # even when those models are not available in the environment.
+ try:
+     from transformers.models.gemma3.modeling_gemma3 import Gemma3CausalLMOutputWithPast as _Gemma3CausalLMOutputWithPast
+ except Exception:
+     _Gemma3CausalLMOutputWithPast = None
+
+ try:
+     from transformers.models.glm4v_moe.modeling_glm4v_moe import (
+         Glm4vMoeCausalLMOutputWithPast as _Glm4vMoeCausalLMOutputWithPast,
+     )
+ except Exception:
+     _Glm4vMoeCausalLMOutputWithPast = None
+
+ try:
+     from transformers.models.internvl.modeling_internvl import (
+         InternVLCausalLMOutputWithPast as _InternVLCausalLMOutputWithPast,
+     )
+ except Exception:
+     _InternVLCausalLMOutputWithPast = None
+
+ try:
+     from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast as _LlavaCausalLMOutputWithPast
+ except Exception:
+     _LlavaCausalLMOutputWithPast = None
+
+ try:
+     from transformers.models.paligemma.modeling_paligemma import (
+         PaliGemmaCausalLMOutputWithPast as _PaliGemmaCausalLMOutputWithPast,
+     )
+ except Exception:
+     _PaliGemmaCausalLMOutputWithPast = None
+
+ try:
+     from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+         Qwen2_5_VLCausalLMOutputWithPast as _Qwen2_5_VLCausalLMOutputWithPast,
+     )
+ except Exception:
+     _Qwen2_5_VLCausalLMOutputWithPast = None
+
+ try:
+     from transformers.models.qwen2_vl.modeling_qwen2_vl import (
+         Qwen2VLCausalLMOutputWithPast as _Qwen2VLCausalLMOutputWithPast,
+     )
+ except Exception:
+     _Qwen2VLCausalLMOutputWithPast = None
+
+ try:
+     from transformers.models.qwen3_vl.modeling_qwen3_vl import (
+         Qwen3VLCausalLMOutputWithPast as _Qwen3VLCausalLMOutputWithPast,
+     )
+ except Exception:
+     _Qwen3VLCausalLMOutputWithPast = None
+
+ try:
+     from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
+         Qwen3VLMoeCausalLMOutputWithPast as _Qwen3VLMoeCausalLMOutputWithPast,
+     )
+ except Exception:
+     _Qwen3VLMoeCausalLMOutputWithPast = None
+
+
+ @dataclass
+ class LigerCausalLMOutputWithPast(CausalLMOutputWithPast):
+     token_accuracy: Optional[torch.FloatTensor] = None
+
+
+ @dataclass
+ class LigerMoeCausalLMOutputWithPast(MoeCausalLMOutputWithPast):
+     token_accuracy: Optional[torch.FloatTensor] = None
+
+
+ if _Gemma3CausalLMOutputWithPast is not None:
+
+     @dataclass
+     class LigerGemma3CausalLMOutputWithPast(_Gemma3CausalLMOutputWithPast):
+         token_accuracy: Optional[torch.FloatTensor] = None
+
+
+ if _Glm4vMoeCausalLMOutputWithPast is not None:
+
+     @dataclass
+     class LigerGlm4vMoeCausalLMOutputWithPast(_Glm4vMoeCausalLMOutputWithPast):
+         token_accuracy: Optional[torch.FloatTensor] = None
+
+
+ if _LlavaCausalLMOutputWithPast is not None:
+
+     @dataclass
+     class LigerLlavaCausalLMOutputWithPast(_LlavaCausalLMOutputWithPast):
+         token_accuracy: Optional[torch.FloatTensor] = None
+
+
+ if _InternVLCausalLMOutputWithPast is not None:
+
+     @dataclass
+     class LigerInternVLCausalLMOutputWithPast(_InternVLCausalLMOutputWithPast):
+         token_accuracy: Optional[torch.FloatTensor] = None
+
+
+ if _PaliGemmaCausalLMOutputWithPast is not None:
+
+     @dataclass
+     class LigerPaliGemmaCausalLMOutputWithPast(_PaliGemmaCausalLMOutputWithPast):
+         token_accuracy: Optional[torch.FloatTensor] = None
+
+
+ if _Qwen2_5_VLCausalLMOutputWithPast is not None:
+
+     @dataclass
+     class LigerQwen2_5_VLCausalLMOutputWithPast(_Qwen2_5_VLCausalLMOutputWithPast):
+         token_accuracy: Optional[torch.FloatTensor] = None
+
+
+ if _Qwen2VLCausalLMOutputWithPast is not None:
+
+     @dataclass
+     class LigerQwen2VLCausalLMOutputWithPast(_Qwen2VLCausalLMOutputWithPast):
+         token_accuracy: Optional[torch.FloatTensor] = None
+
+
+ if _Qwen3VLCausalLMOutputWithPast is not None:
+
+     @dataclass
+     class LigerQwen3VLCausalLMOutputWithPast(_Qwen3VLCausalLMOutputWithPast):
+         token_accuracy: Optional[torch.FloatTensor] = None
+
+
+ if _Qwen3VLMoeCausalLMOutputWithPast is not None:
+
+     @dataclass
+     class LigerQwen3VLMoeCausalLMOutputWithPast(_Qwen3VLMoeCausalLMOutputWithPast):
+         token_accuracy: Optional[torch.FloatTensor] = None
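
Each class above only adds an optional `token_accuracy` field on top of the corresponding transformers output class, so code written against the stock outputs keeps working, and the guarded imports keep the module importable on transformers versions that lack a given model. A short sketch of the base class in isolation, with toy values; the accuracy semantics are inferred from the field name rather than stated in the diff:

```python
import torch
from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast

# Toy construction: in practice lce_forward fills these fields.
out = LigerCausalLMOutputWithPast(
    loss=torch.tensor(1.23),
    logits=None,                        # None when the fused loss path skipped logits
    token_accuracy=torch.tensor(0.57),  # illustrative value
)

print(out.loss.item())  # 1.23
print(out["loss"])      # ModelOutput still supports dict-style access

# token_accuracy is optional, so downstream code should guard for None.
if out.token_accuracy is not None:
    print(f"token accuracy: {out.token_accuracy.item():.2f}")
```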