liger-kernel-nightly 0.4.0.dev20241107052928__py3-none-any.whl → 0.6.3.dev20251121010306__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of liger-kernel-nightly might be problematic. Click here for more details.

Files changed (114)
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/README.md +25 -0
  3. liger_kernel/chunked_loss/__init__.py +8 -0
  4. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  5. liger_kernel/chunked_loss/cpo_loss.py +157 -0
  6. liger_kernel/chunked_loss/dpo_loss.py +229 -0
  7. liger_kernel/chunked_loss/functional.py +17 -0
  8. liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
  9. liger_kernel/chunked_loss/fused_linear_ppo.py +350 -0
  10. liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
  11. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
  12. liger_kernel/chunked_loss/grpo_loss.py +304 -0
  13. liger_kernel/chunked_loss/jsd_loss.py +200 -0
  14. liger_kernel/chunked_loss/kto_loss.py +210 -0
  15. liger_kernel/chunked_loss/orpo_loss.py +144 -0
  16. liger_kernel/chunked_loss/simpo_loss.py +165 -0
  17. liger_kernel/env_report.py +21 -4
  18. liger_kernel/ops/cross_entropy.py +235 -84
  19. liger_kernel/ops/dyt.py +157 -0
  20. liger_kernel/ops/experimental/embedding.py +1 -3
  21. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  22. liger_kernel/ops/fused_add_rms_norm.py +412 -0
  23. liger_kernel/ops/fused_linear_cross_entropy.py +197 -75
  24. liger_kernel/ops/fused_linear_jsd.py +17 -34
  25. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  26. liger_kernel/ops/geglu.py +7 -18
  27. liger_kernel/ops/group_norm.py +305 -0
  28. liger_kernel/ops/grpo_loss.py +310 -0
  29. liger_kernel/ops/jsd.py +46 -21
  30. liger_kernel/ops/kl_div.py +23 -19
  31. liger_kernel/ops/layer_norm.py +150 -86
  32. liger_kernel/ops/llama4_rope.py +225 -0
  33. liger_kernel/ops/multi_token_attention.py +207 -0
  34. liger_kernel/ops/poly_norm.py +386 -0
  35. liger_kernel/ops/qwen2vl_mrope.py +222 -0
  36. liger_kernel/ops/rms_norm.py +314 -84
  37. liger_kernel/ops/rope.py +32 -34
  38. liger_kernel/ops/softmax.py +201 -0
  39. liger_kernel/ops/sparsemax.py +179 -0
  40. liger_kernel/ops/swiglu.py +5 -9
  41. liger_kernel/ops/tiled_mlp.py +136 -0
  42. liger_kernel/ops/tvd.py +207 -0
  43. liger_kernel/ops/utils.py +8 -4
  44. liger_kernel/transformers/__init__.py +199 -24
  45. liger_kernel/transformers/auto_model.py +6 -13
  46. liger_kernel/transformers/cross_entropy.py +33 -20
  47. liger_kernel/transformers/dyt.py +22 -0
  48. liger_kernel/transformers/experimental/__init__.py +5 -0
  49. liger_kernel/transformers/experimental/embedding.py +1 -3
  50. liger_kernel/transformers/fsdp.py +55 -0
  51. liger_kernel/transformers/functional.py +291 -13
  52. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  53. liger_kernel/transformers/fused_linear_cross_entropy.py +43 -14
  54. liger_kernel/transformers/fused_linear_jsd.py +1 -4
  55. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  56. liger_kernel/transformers/geglu.py +1 -4
  57. liger_kernel/transformers/group_norm.py +50 -0
  58. liger_kernel/transformers/grpo_loss.py +98 -0
  59. liger_kernel/transformers/jsd.py +2 -7
  60. liger_kernel/transformers/kl_div.py +1 -3
  61. liger_kernel/transformers/layer_norm.py +3 -9
  62. liger_kernel/transformers/llama4_rope.py +93 -0
  63. liger_kernel/transformers/model/falcon_h1.py +122 -0
  64. liger_kernel/transformers/model/gemma.py +77 -77
  65. liger_kernel/transformers/model/gemma2.py +283 -0
  66. liger_kernel/transformers/model/gemma3.py +331 -0
  67. liger_kernel/transformers/model/glm4.py +141 -0
  68. liger_kernel/transformers/model/glm4v.py +163 -0
  69. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  70. liger_kernel/transformers/model/internvl.py +157 -0
  71. liger_kernel/transformers/model/llama.py +128 -79
  72. liger_kernel/transformers/model/llama4.py +121 -0
  73. liger_kernel/transformers/model/llava.py +344 -0
  74. liger_kernel/transformers/model/loss_utils.py +95 -0
  75. liger_kernel/transformers/model/mistral.py +68 -64
  76. liger_kernel/transformers/model/mixtral.py +75 -91
  77. liger_kernel/transformers/model/mllama.py +63 -68
  78. liger_kernel/transformers/model/olmo2.py +141 -0
  79. liger_kernel/transformers/model/output_classes.py +147 -0
  80. liger_kernel/transformers/model/paligemma.py +432 -0
  81. liger_kernel/transformers/model/phi3.py +59 -213
  82. liger_kernel/transformers/model/qwen2.py +75 -72
  83. liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
  84. liger_kernel/transformers/model/qwen2_vl.py +78 -98
  85. liger_kernel/transformers/model/qwen3.py +136 -0
  86. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  87. liger_kernel/transformers/model/qwen3_next.py +146 -0
  88. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  89. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  90. liger_kernel/transformers/model/smollm3.py +199 -0
  91. liger_kernel/transformers/model/smolvlm.py +158 -0
  92. liger_kernel/transformers/monkey_patch.py +2106 -289
  93. liger_kernel/transformers/multi_token_attention.py +64 -0
  94. liger_kernel/transformers/poly_norm.py +42 -0
  95. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  96. liger_kernel/transformers/rms_norm.py +57 -6
  97. liger_kernel/transformers/rope.py +45 -2
  98. liger_kernel/transformers/softmax.py +12 -0
  99. liger_kernel/transformers/sparsemax.py +16 -0
  100. liger_kernel/transformers/swiglu.py +23 -8
  101. liger_kernel/transformers/tiled_mlp.py +133 -0
  102. liger_kernel/transformers/trainer/__init__.py +4 -0
  103. liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
  104. liger_kernel/transformers/tvd.py +13 -0
  105. liger_kernel/triton/__init__.py +1 -3
  106. liger_kernel/triton/monkey_patch.py +1 -3
  107. liger_kernel/utils.py +71 -0
  108. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +150 -137
  109. liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
  110. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +1 -1
  111. liger_kernel_nightly-0.4.0.dev20241107052928.dist-info/RECORD +0 -48
  112. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
  113. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
  114. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,165 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase
5
+
6
+
7
class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
    @staticmethod
    def preference_loss_fn(
        chosen_logps,
        rejected_logps,
        full_target,
        beta=0.1,
        gamma=0.5,
        label_smoothing=0.0,
    ):
        """
        Compute the SimPO preference loss.

        Paper: https://arxiv.org/pdf/2405.14734

        Formula:
        L_SimPO(π_θ) = -E [log σ(β/|y_w| log π_θ(y_w|x) - β/|y_l| log π_θ(y_l|x) - γ)]

        Where:
        - π_θ(y|x): Policy (model) probability
        - y_w: Chosen sequence
        - y_l: Rejected sequence
        - |y_w|, |y_l|: Sequence lengths
        - σ: Sigmoid function
        - β: beta reward-scaling weight
        - γ: gamma target reward margin

        Args:
            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
            full_target: Non chunked full target tensor
            beta (float): Reward-scaling weight (β in the paper).
            gamma (float): Target reward margin (γ in the paper).
            label_smoothing (float): Label smoothing factor, will reduce to Equation above when label_smoothing -> 0.
        """
        logits = beta * (chosen_logps - rejected_logps) - gamma
        # Label-smoothed sigmoid loss on the margin logit, normalized by the
        # number of chosen sequences (full_target stacks chosen + rejected,
        # hence the division by half its leading dimension).
        loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / (
            full_target.shape[0] // 2
        )

        # Implicit rewards: beta-scaled (length-normalized) log-probabilities.
        chosen_rewards = beta * chosen_logps
        rejected_rewards = beta * rejected_logps

        return loss, chosen_rewards, rejected_rewards

    @classmethod
    def forward(
        cls,
        ctx,
        _input,
        weight,
        target,
        bias=None,
        ignore_index=-100,
        beta=0.1,
        alpha=1.0,
        label_smoothing=0.0,
        compute_nll_loss=False,
        compiled=True,
        gamma=0.5,
        chunk_size=1,
    ):
        """
        Fused linear layer with SimPO loss.

        Args:
            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
            target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
            ignore_index (int): Index to ignore in loss computation
            beta (float): Reward-scaling weight (β in the SimPO paper)
            alpha (float): Weight for the alpha parameter
            label_smoothing (float): Label smoothing factor
            compute_nll_loss (bool): Whether to compute the NLL loss
            compiled (bool): Whether to use torch compile
            gamma (float): Target reward margin (γ in the SimPO paper)
            chunk_size (int): Size of chunks for processing
        Returns:
            torch.Tensor: Computed loss
        """
        return super().forward(
            cls=cls,
            ctx=ctx,
            _input=_input,
            weight=weight,
            target=target,
            bias=bias,
            ignore_index=ignore_index,
            alpha=alpha,
            beta=beta,
            label_smoothing=label_smoothing,
            compute_nll_loss=compute_nll_loss,
            compiled=compiled,
            gamma=gamma,
            chunk_size=chunk_size,
        )

    @staticmethod
    def backward(ctx, *grad_output):
        # Only the first four gradients (_input, weight, target, bias) are
        # meaningful; the remaining eight Nones match the eight non-tensor
        # hyperparameters of forward().
        grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
        return *grads, None, None, None, None, None, None, None, None
106
+
107
+
108
class LigerFusedLinearSimPOLoss(torch.nn.Module):
    """Module wrapper around the fused linear + SimPO loss function."""

    def __init__(
        self,
        ignore_index: int = -100,
        beta: float = 0.1,
        alpha: float = 1.0,
        label_smoothing: float = 0.0,
        compute_nll_loss: bool = True,
        compiled: bool = True,
        gamma: float = 0.5,
        chunk_size: int = 1,
    ):
        """
        Args:
            ignore_index (int): Target index excluded from the loss.
            beta (float): Reward-scaling weight.
            alpha (float): Weight for the alpha parameter.
            label_smoothing (float): Label smoothing factor.
            compute_nll_loss (bool): Whether to also compute the NLL loss.
            compiled (bool): Whether to run the torch-compiled kernel.
            gamma (float): Target reward margin.
            chunk_size (int): Chunk size used when processing the batch.
        """
        super().__init__()
        # Stash every hyperparameter; forward() forwards them positionally.
        self.ignore_index = ignore_index
        self.beta = beta
        self.alpha = alpha
        self.label_smoothing = label_smoothing
        self.compute_nll_loss = compute_nll_loss
        self.compiled = compiled
        self.gamma = gamma
        self.chunk_size = chunk_size

    def forward(
        self,
        lin_weight,
        _input,
        target,
        bias=None,
    ):
        # autograd.Function.apply takes positional arguments only, in the
        # exact order declared by LigerFusedLinearSimPOFunction.forward.
        fn_args = (
            _input,
            lin_weight,
            target,
            bias,
            self.ignore_index,
            self.beta,
            self.alpha,
            self.label_smoothing,
            self.compute_nll_loss,
            self.compiled,
            self.gamma,
            self.chunk_size,
        )
        return LigerFusedLinearSimPOFunction.apply(*fn_args)
@@ -1,31 +1,42 @@
1
1
  import platform
2
2
  import sys
3
3
 
4
+ from importlib.metadata import version
5
+
4
6
 
5
7
  def print_env_report():
6
8
  """
7
- Prints a report of the environment. Useful for debugging and reproducibility.
9
+
10
+ Prints a report of the environment. Useful for debugging and reproducibility.
8
11
  Usage:
9
12
  ```
10
13
  python -m liger_kernel.env_report
11
14
  ```
15
+
12
16
  """
13
17
  print("Environment Report:")
14
18
  print("-------------------")
15
19
  print(f"Operating System: {platform.platform()}")
16
20
  print(f"Python version: {sys.version.split()[0]}")
17
21
 
22
+ try:
23
+ print(f"Liger Kernel version: {version('liger-kernel')}")
24
+ except ImportError:
25
+ print("Liger Kernel: Not installed")
26
+
18
27
  try:
19
28
  import torch
20
29
 
21
30
  print(f"PyTorch version: {torch.__version__}")
22
- cuda_version = (
23
- torch.version.cuda if torch.cuda.is_available() else "Not available"
24
- )
31
+ cuda_version = torch.version.cuda if torch.cuda.is_available() else "Not available"
25
32
  print(f"CUDA version: {cuda_version}")
33
+ hip_version = torch.version.hip if torch.cuda.is_available() and torch.version.hip else "Not available"
34
+ print(f"HIP(ROCm) version: {hip_version}")
35
+
26
36
  except ImportError:
27
37
  print("PyTorch: Not installed")
28
38
  print("CUDA version: Unable to query")
39
+ print("HIP(ROCm) version: Unable to query")
29
40
 
30
41
  try:
31
42
  import triton
@@ -41,6 +52,12 @@ def print_env_report():
41
52
  except ImportError:
42
53
  print("Transformers: Not installed")
43
54
 
55
+ try:
56
+ xpu_version = torch.version.xpu if torch.xpu.is_available() else "XPU Not Available"
57
+ print(f"XPU version: {xpu_version}")
58
+ except ImportError:
59
+ print("XPU version: Unable to query")
60
+
44
61
 
45
62
  if __name__ == "__main__":
46
63
  print_env_report()