liger-kernel 0.5.10__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +127 -0
  3. liger_kernel/chunked_loss/functional.py +2 -0
  4. liger_kernel/ops/dyt.py +0 -2
  5. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  6. liger_kernel/ops/geglu.py +1 -1
  7. liger_kernel/ops/multi_token_attention.py +207 -0
  8. liger_kernel/ops/rms_norm.py +265 -54
  9. liger_kernel/ops/softmax.py +201 -0
  10. liger_kernel/ops/sparsemax.py +62 -50
  11. liger_kernel/ops/swiglu.py +1 -1
  12. liger_kernel/transformers/__init__.py +3 -0
  13. liger_kernel/transformers/functional.py +62 -0
  14. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  15. liger_kernel/transformers/model/gemma.py +25 -8
  16. liger_kernel/transformers/model/gemma2.py +27 -8
  17. liger_kernel/transformers/model/gemma3.py +62 -98
  18. liger_kernel/transformers/model/glm4.py +16 -7
  19. liger_kernel/transformers/model/llama.py +25 -7
  20. liger_kernel/transformers/model/llama4.py +108 -0
  21. liger_kernel/transformers/model/llava.py +95 -124
  22. liger_kernel/transformers/model/mistral.py +13 -8
  23. liger_kernel/transformers/model/mixtral.py +16 -7
  24. liger_kernel/transformers/model/mllama.py +16 -7
  25. liger_kernel/transformers/model/olmo2.py +16 -7
  26. liger_kernel/transformers/model/paligemma.py +8 -1
  27. liger_kernel/transformers/model/phi3.py +25 -8
  28. liger_kernel/transformers/model/qwen2.py +24 -7
  29. liger_kernel/transformers/model/qwen2_5_vl.py +41 -91
  30. liger_kernel/transformers/model/qwen2_vl.py +38 -100
  31. liger_kernel/transformers/model/qwen3.py +11 -3
  32. liger_kernel/transformers/model/qwen3_moe.py +10 -6
  33. liger_kernel/transformers/monkey_patch.py +304 -70
  34. liger_kernel/transformers/multi_token_attention.py +64 -0
  35. liger_kernel/transformers/rms_norm.py +40 -4
  36. liger_kernel/transformers/softmax.py +12 -0
  37. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/METADATA +8 -2
  38. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/RECORD +42 -35
  39. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/WHEEL +1 -1
  40. liger_kernel/transformers/gema3_rms.py +0 -8
  41. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/licenses/LICENSE +0 -0
  42. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/licenses/NOTICE +0 -0
  43. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,4 @@
1
+ from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityLoss # noqa:F401
1
2
  from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401
2
3
  from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401
3
4
  from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss # noqa: F401
@@ -0,0 +1,127 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase
5
+
6
+
7
class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase):
    """Autograd function fusing the student/teacher linear projections with a
    cosine-similarity distillation loss, evaluated chunk-by-chunk via
    LigerFusedLinearDistillationBase."""

    @staticmethod
    def distillation_loss_fn(student_logits, teacher_logits, beta=1.0):
        """
        Compute the cosine-similarity distillation loss.

        Args:
            student_logits (torch.Tensor): Student logits; cosine similarity is
                taken over the last dimension.
            teacher_logits (torch.Tensor): Teacher logits, same shape as
                ``student_logits``.
            beta (float): Scale coefficient of the loss, expected in [0, 1].
                Default: ``1.0``.
        Returns:
            torch.Tensor: Scalar loss, the sum of ``beta * (1 - cos_sim)`` over rows.
        """
        student_norm = F.normalize(student_logits, p=2, dim=-1)
        teacher_norm = F.normalize(teacher_logits, p=2, dim=-1)

        # Both operands are already unit-length, so their cosine similarity is
        # simply the dot product. Calling F.cosine_similarity here (as a naive
        # implementation would) redundantly re-normalizes both tensors.
        cosine_sim = (student_norm * teacher_norm).sum(dim=-1)
        loss = beta * (1.0 - cosine_sim)
        return loss.sum()

    @classmethod
    def forward(
        cls,
        ctx,
        student_input: torch.Tensor,
        student_weight: torch.Tensor,
        teacher_input: torch.Tensor,
        teacher_weight: torch.Tensor,
        true_labels: torch.LongTensor,
        student_bias: torch.Tensor,
        teacher_bias: torch.Tensor,
        weight_hard_loss: float = 0.5,
        weight_soft_loss: float = 0.5,
        beta: float = 0.5,
        ignore_index: int = -100,
        temperature: float = 1.0,
        compiled: bool = True,
        chunk_size: int = 1024,
    ):
        """Delegate to the distillation base, which chunks the batch and mixes
        the hard (cross-entropy vs. ``true_labels``) and soft (cosine) losses."""
        return super().forward(
            cls=cls,
            ctx=ctx,
            student_input=student_input,
            student_weight=student_weight,
            teacher_input=teacher_input,
            teacher_weight=teacher_weight,
            target=true_labels,
            student_bias=student_bias,
            teacher_bias=teacher_bias,
            chunk_size=chunk_size,
            weight_hard_loss=weight_hard_loss,
            weight_soft_loss=weight_soft_loss,
            beta=beta,
            ignore_index=ignore_index,
            temperature=temperature,
            compiled=compiled,
        )

    @staticmethod
    def backward(ctx, grad_output):
        # The base produces gradients only for the first six forward inputs
        # (through student_bias); every remaining argument is a non-tensor
        # hyperparameter and therefore receives None.
        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:6]

        return (
            *grads,
            None,  # teacher_bias
            None,  # weight_hard_loss
            None,  # weight_soft_loss
            None,  # beta
            None,  # ignore_index
            None,  # temperature
            None,  # compiled
            None,  # chunk_size
        )
79
+
80
+
81
class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
    """Module wrapper around LigerFusedLinearCosineSimilarityFunction.

    Combines a hard cross-entropy loss against ``true_labels`` with a soft
    cosine-similarity distillation loss between student and teacher logits,
    computed in fused chunks.
    """

    def __init__(
        self,
        weight_hard_loss: float = 0.5,
        weight_soft_loss: float = 0.5,
        beta: float = 0.5,
        ignore_index: int = -100,
        temperature: float = 1.0,
        compiled: bool = True,
        chunk_size: int = 1024,
    ):
        """
        Args:
            weight_hard_loss (float): Weight of the hard (label) loss.
            weight_soft_loss (float): Weight of the soft (distillation) loss.
            beta (float): Coefficient of the cosine-similarity loss, in [0, 1].
            ignore_index (int): Label value excluded from the hard loss.
            temperature (float): Softmax temperature; must be non-zero.
            compiled (bool): Whether to run the chunked loss under torch.compile.
            chunk_size (int): Number of rows processed per chunk.

        Raises:
            ValueError: If ``temperature`` is 0.
        """
        super().__init__()
        # Validate eagerly with a real exception: `assert` would be stripped
        # when Python runs with -O, silently allowing a divide-by-zero later.
        if temperature == 0:
            raise ValueError("Temperature cannot be 0.")
        self.weight_hard_loss = weight_hard_loss
        self.weight_soft_loss = weight_soft_loss
        self.ignore_index = ignore_index
        self.temperature = temperature
        self.compiled = compiled
        self.beta = beta
        self.chunk_size = chunk_size

    def forward(
        self,
        student_input: torch.Tensor,
        student_weight: torch.Tensor,
        teacher_input: torch.Tensor,
        teacher_weight: torch.Tensor,
        true_labels: torch.LongTensor,
        student_bias: torch.Tensor = None,
        teacher_bias: torch.Tensor = None,
    ) -> torch.Tensor:
        """Compute the combined hard + soft distillation loss (scalar tensor)."""
        return LigerFusedLinearCosineSimilarityFunction.apply(
            student_input,
            student_weight,
            teacher_input,
            teacher_weight,
            true_labels,
            student_bias,
            teacher_bias,
            self.weight_hard_loss,
            self.weight_soft_loss,
            self.beta,
            self.ignore_index,
            self.temperature,
            self.compiled,
            self.chunk_size,
        )
@@ -1,3 +1,4 @@
1
+ from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityFunction
1
2
  from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
2
3
  from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
3
4
  from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
@@ -9,6 +10,7 @@ from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction
9
10
  liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply
10
11
  liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply
11
12
  liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
13
+ liger_fused_linear_cosine = LigerFusedLinearCosineSimilarityFunction.apply
12
14
  liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
13
15
  liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
14
16
  liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply
liger_kernel/ops/dyt.py CHANGED
@@ -4,8 +4,6 @@ import torch
4
4
  import triton
5
5
  import triton.language as tl
6
6
 
7
- from triton.language.extra.libdevice import tanh
8
-
9
7
  from liger_kernel.ops.utils import compare_version
10
8
  from liger_kernel.ops.utils import ensure_contiguous
11
9
  from liger_kernel.ops.utils import infer_device