liger-kernel 0.5.1__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. liger_kernel/chunked_loss/README.md +25 -0
  2. liger_kernel/chunked_loss/__init__.py +2 -0
  3. liger_kernel/chunked_loss/cpo_loss.py +18 -8
  4. liger_kernel/chunked_loss/dpo_loss.py +20 -10
  5. liger_kernel/chunked_loss/functional.py +4 -0
  6. liger_kernel/chunked_loss/fused_linear_distillation.py +58 -44
  7. liger_kernel/chunked_loss/fused_linear_preference.py +108 -60
  8. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +246 -0
  9. liger_kernel/chunked_loss/jsd_loss.py +154 -0
  10. liger_kernel/chunked_loss/kto_loss.py +172 -0
  11. liger_kernel/chunked_loss/orpo_loss.py +8 -9
  12. liger_kernel/chunked_loss/simpo_loss.py +22 -8
  13. liger_kernel/env_report.py +5 -12
  14. liger_kernel/ops/cross_entropy.py +102 -51
  15. liger_kernel/ops/experimental/embedding.py +1 -3
  16. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  17. liger_kernel/ops/fused_linear_cross_entropy.py +89 -55
  18. liger_kernel/ops/fused_linear_jsd.py +11 -29
  19. liger_kernel/ops/geglu.py +6 -17
  20. liger_kernel/ops/group_norm.py +11 -28
  21. liger_kernel/ops/jsd.py +2 -6
  22. liger_kernel/ops/kl_div.py +8 -11
  23. liger_kernel/ops/layer_norm.py +3 -5
  24. liger_kernel/ops/qwen2vl_mrope.py +21 -37
  25. liger_kernel/ops/rms_norm.py +14 -32
  26. liger_kernel/ops/rope.py +31 -33
  27. liger_kernel/ops/swiglu.py +4 -8
  28. liger_kernel/ops/utils.py +2 -0
  29. liger_kernel/transformers/__init__.py +16 -24
  30. liger_kernel/transformers/auto_model.py +6 -13
  31. liger_kernel/transformers/cross_entropy.py +4 -6
  32. liger_kernel/transformers/experimental/embedding.py +1 -3
  33. liger_kernel/transformers/functional.py +11 -7
  34. liger_kernel/transformers/fused_linear_cross_entropy.py +12 -7
  35. liger_kernel/transformers/geglu.py +1 -4
  36. liger_kernel/transformers/group_norm.py +3 -9
  37. liger_kernel/transformers/jsd.py +1 -3
  38. liger_kernel/transformers/kl_div.py +1 -3
  39. liger_kernel/transformers/layer_norm.py +3 -9
  40. liger_kernel/transformers/model/gemma.py +18 -40
  41. liger_kernel/transformers/model/gemma2.py +19 -41
  42. liger_kernel/transformers/model/llama.py +22 -48
  43. liger_kernel/transformers/model/mistral.py +14 -26
  44. liger_kernel/transformers/model/mixtral.py +24 -54
  45. liger_kernel/transformers/model/mllama.py +16 -36
  46. liger_kernel/transformers/model/phi3.py +18 -40
  47. liger_kernel/transformers/model/qwen2.py +18 -40
  48. liger_kernel/transformers/model/qwen2_vl.py +36 -32
  49. liger_kernel/transformers/monkey_patch.py +43 -117
  50. liger_kernel/transformers/qwen2vl_mrope.py +2 -2
  51. liger_kernel/transformers/rms_norm.py +4 -4
  52. liger_kernel/transformers/rope.py +2 -2
  53. liger_kernel/transformers/swiglu.py +2 -8
  54. liger_kernel/transformers/trainer/__init__.py +1 -3
  55. liger_kernel/transformers/trainer/orpo_trainer.py +31 -18
  56. liger_kernel/triton/__init__.py +1 -3
  57. liger_kernel/triton/monkey_patch.py +1 -3
  58. {liger_kernel-0.5.1.dist-info → liger_kernel-0.5.3.dist-info}/METADATA +38 -25
  59. liger_kernel-0.5.3.dist-info/RECORD +69 -0
  60. {liger_kernel-0.5.1.dist-info → liger_kernel-0.5.3.dist-info}/WHEEL +1 -1
  61. liger_kernel-0.5.1.dist-info/RECORD +0 -65
  62. {liger_kernel-0.5.1.dist-info → liger_kernel-0.5.3.dist-info}/LICENSE +0 -0
  63. {liger_kernel-0.5.1.dist-info → liger_kernel-0.5.3.dist-info}/NOTICE +0 -0
  64. {liger_kernel-0.5.1.dist-info → liger_kernel-0.5.3.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/jsd_loss.py (new file)
@@ -0,0 +1,154 @@
+ import torch
+ import torch.nn.functional as F
+
+ from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase
+
+
+ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
+     @staticmethod
+     def distillation_loss_fn(student_logits, teacher_logits, beta=0.5):
+         """
+         Compute JSD loss (Jensen-Shannon Divergence Loss).
+         Args:
+             student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len,).
+             teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len,).
+             beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+         Returns:
+             torch.Tensor: Jensen-Shannon Divergence loss
+         """
+         student_log_probs = F.log_softmax(student_logits, dim=-1)
+         teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
+
+         # Compute probabilities (only required for mean calculation)
+         mean_probs = beta * student_log_probs.exp() + (1 - beta) * teacher_log_probs.exp()
+         log_mean_probs = mean_probs.log()
+
+         student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
+         teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
+
+         # JSD is the weighted average of the KL divergences
+         jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
+         return jsd_loss
+
+     @staticmethod
+     def forward(
+         ctx,
+         student_input: torch.Tensor,
+         student_weight: torch.Tensor,
+         teacher_input: torch.Tensor,
+         teacher_weight: torch.Tensor,
+         true_labels: torch.LongTensor,
+         weight_hard_loss: float = 0.5,
+         weight_soft_loss: float = 0.5,
+         beta: float = 0.5,
+         ignore_index: int = -100,
+         temperature: float = 1.0,
+         compiled: bool = True,
+     ):
+         """
+         Fused linear layer with JSD distillation loss.
+         Args:
+             student_input (torch.Tensor): Student input tensor. Shape: (batch_size * seq_len, hidden_size_student)
+             student_weight (torch.Tensor): Student weight tensor. Shape: (vocab_size, hidden_size_student)
+             teacher_input (torch.Tensor): Teacher input tensor. Shape: (batch_size * seq_len, hidden_size_teacher)
+             teacher_weight (torch.Tensor): Teacher weight tensor. Shape: (vocab_size, hidden_size_teacher)
+             true_labels (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+             weight_hard_loss (float): Weight for hard loss.
+             weight_soft_loss (float): Weight for soft loss.
+             beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+             ignore_index (int): Index to ignore in loss computation
+             temperature (float): Temperature for softening/sharpening distributions
+             compiled (bool): Whether to use torch compile
+         Returns:
+             torch.Tensor: Computed loss
+         """
+         return LigerFusedLinearDistillationBase.forward(
+             ctx=ctx,
+             student_input=student_input,
+             student_weight=student_weight,
+             teacher_input=teacher_input,
+             teacher_weight=teacher_weight,
+             target=true_labels,
+             loss_fn=LigerFusedLinearJSDFunction.distillation_loss_fn,
+             chunk_size=1,
+             weight_hard_loss=weight_hard_loss,
+             weight_soft_loss=weight_soft_loss,
+             beta=beta,
+             ignore_index=ignore_index,
+             temperature=temperature,
+             compiled=compiled,
+         )
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:4]
+
+         return (*grads, None, None, None, None, None, None, None)
+
+
+ class LigerFusedLinearJSDLoss(torch.nn.Module):
+     """
+     Fused linear layer with JSD distillation loss.
+     """
+
+     def __init__(
+         self,
+         weight_hard_loss: float = 0.5,
+         weight_soft_loss: float = 0.5,
+         beta: float = 0.5,
+         ignore_index: int = -100,
+         temperature: float = 1.0,
+         compiled: bool = True,
+     ):
+         """
+         Args:
+             weight_hard_loss (float): Weight for hard loss.
+             weight_soft_loss (float): Weight for soft loss.
+             ignore_index (int): Index to ignore in the loss
+             temperature (float): Temperature for softening distributions
+             compiled (bool): Whether to use torch compile
+             beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+         """
+         super().__init__()
+         assert temperature != 0, "Temperature cannot be 0."
+         self.weight_hard_loss = weight_hard_loss
+         self.weight_soft_loss = weight_soft_loss
+         self.ignore_index = ignore_index
+         self.temperature = temperature
+         self.compiled = compiled
+         self.beta = beta
+
+     def forward(
+         self,
+         student_input: torch.Tensor,
+         student_weight: torch.Tensor,
+         teacher_input: torch.Tensor,
+         teacher_weight: torch.Tensor,
+         true_labels: torch.LongTensor,
+     ) -> torch.Tensor:
+         """
+         Compute the JSD distillation loss.
+
+         Args:
+             student_input (torch.Tensor): Student input tensor
+             student_weight (torch.Tensor): Student weight tensor
+             teacher_input (torch.Tensor): Teacher input tensor
+             teacher_weight (torch.Tensor): Teacher weight tensor
+             true_labels (torch.LongTensor): Target labels tensor
+
+         Returns:
+             torch.Tensor: Computed loss
+         """
+         return LigerFusedLinearJSDFunction.apply(
+             student_input,
+             student_weight,
+             teacher_input,
+             teacher_weight,
+             true_labels,
+             self.weight_hard_loss,
+             self.weight_soft_loss,
+             self.beta,
+             self.ignore_index,
+             self.temperature,
+             self.compiled,
+         )
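For orientation, a minimal usage sketch of the new LigerFusedLinearJSDLoss module. The sizes and tensors below are hypothetical, chosen only to match the shapes documented in the forward docstring above; compiled=False is set here just to keep the sketch free of torch.compile, and real training would typically run on GPU with the default compiled=True.

import torch
from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss

# Hypothetical sizes for illustration only: (batch*seq, student hidden, teacher hidden, vocab).
BT, H_STUDENT, H_TEACHER, V = 8, 64, 128, 512

loss_fn = LigerFusedLinearJSDLoss(weight_hard_loss=0.5, weight_soft_loss=0.5, beta=0.5, compiled=False)

student_input = torch.randn(BT, H_STUDENT)   # student hidden states
student_weight = torch.randn(V, H_STUDENT)   # student lm_head weight
teacher_input = torch.randn(BT, H_TEACHER)   # teacher hidden states
teacher_weight = torch.randn(V, H_TEACHER)   # teacher lm_head weight
true_labels = torch.randint(0, V, (BT,))     # hard targets for the hard-loss term

loss = loss_fn(student_input, student_weight, teacher_input, teacher_weight, true_labels)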
liger_kernel/chunked_loss/kto_loss.py (new file)
@@ -0,0 +1,172 @@
+ import torch
+ import torch.nn.functional as F
+
+ from liger_kernel.chunked_loss.fused_linear_unpaired_preference import LigerFusedLinearUnpairedPreferenceBase
+
+
+ class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
+     @staticmethod
+     def preference_loss_fn(
+         average_log_prob_chunk,
+         preference_labels_chunk,
+         full_target,
+         ref_average_log_prob_chunk=None,
+         beta=0.1,
+         kl=None,
+     ):
+         """
+         Implements the Kahneman-Tversky Optimization (KTO) loss function.
+         Paper: "KTO: Model Alignment as Prospect Theory-Guided Optimization"
+         https://arxiv.org/abs/2402.01306
+
+         KTO loss is inspired by prospect theory (https://en.wikipedia.org/wiki/Prospect_theory)
+         from behavioral economics, which models how humans make decisions under uncertainty.
+         The loss function is asymmetric, treating gains and losses differently, similar to
+         human decision-making patterns.
+
+         Formula:
+         When y is chosen:
+         L_KTO = 1 - σ(β * (log[π(x)/π₀(x)] - KL(π||π₀)_y))
+         When y is rejected:
+         L_KTO = 1 - σ(β * (KL(π||π₀)_y - log[π(x)/π₀(x)]))
+
+         Where:
+         - σ: Sigmoid function
+         - β: Temperature parameter controlling the strength of the preference signal
+         - π(x): Policy (current model)
+         - π₀(x): Reference policy (reference model)
+         - KL(π||π₀)_y: KL divergence estimated using the rejected response y
+
+         The loss encourages the model to:
+         1. Assign higher probability to chosen responses
+         2. Assign lower probability to rejected responses
+         3. Maintain reasonable distance from the reference model
+
+         Args:
+             chosen_logps: Log probabilities of chosen tokens (batch_size,)
+             rejected_logps: Log probabilities of rejected tokens (batch_size,)
+             full_target: Non chunked full target tensor
+             ref_chosen_logps: Reference log probs of chosen tokens (batch_size,)
+             ref_rejected_logps: Reference log probs of rejected tokens (batch_size,)
+             beta: Weight for the direct preference loss
+             kl: KL divergence between the policy model and the reference model for the chosen responses. Shape: (batch_size,)
+         Returns:
+             Tuple of (loss, chosen_rewards, rejected_rewards):
+             - loss: The KTO loss value
+             - chosen_rewards: Reward signals for chosen responses (detached)
+             - rejected_rewards: Reward signals for rejected responses (detached)
+         """
+         logratios_chunk = average_log_prob_chunk - ref_average_log_prob_chunk
+         multiplier_chunk = torch.where(preference_labels_chunk, 1, -1)
+         if kl is not None:
+             losses = 1 - F.sigmoid(beta * (logratios_chunk - kl) * multiplier_chunk)
+         else:
+             losses = 1 - F.sigmoid(beta * logratios_chunk * multiplier_chunk)
+
+         return losses.sum() / (full_target.shape[0])
+
+     @staticmethod
+     def forward(
+         ctx,
+         _input,
+         weight,
+         target,
+         preference_labels,
+         bias=None,
+         ref_input=None,
+         ref_weight=None,
+         ref_bias=None,
+         kl=None,
+         ignore_index=-100,
+         beta=0.1,
+         compiled=True,
+         use_ref_model=True,
+     ):
+         return LigerFusedLinearUnpairedPreferenceBase.forward(
+             ctx=ctx,
+             _input=_input,
+             weight=weight,
+             target=target,
+             preference_labels=preference_labels,
+             bias=bias,
+             loss_fn=LigerFusedLinearKTOFunction.preference_loss_fn,
+             ignore_index=ignore_index,
+             beta=beta,
+             compiled=compiled,
+             use_ref_model=use_ref_model,
+             ref_input=ref_input,
+             ref_weight=ref_weight,
+             ref_bias=ref_bias,
+             kl=kl,
+         )
+
+     @staticmethod
+     def backward(ctx, *grad_output):
+         grads = LigerFusedLinearUnpairedPreferenceBase.backward(ctx, grad_output)[:5]
+         return (
+             *grads,
+             None,
+             None,
+             None,
+             None,
+             None,
+             None,
+             None,
+             None,
+             None,
+             None,
+         )
+
+
+ class LigerFusedLinearKTOLoss(torch.nn.Module):
+     """
+     Fused linear layer with Kahneman-Tversky Optimization (KTO) loss.
+     """
+
+     def __init__(
+         self,
+         ignore_index: int = -100,
+         beta: float = 0.1,
+         compiled: bool = True,
+         use_ref_model: bool = False,
+     ):
+         """
+         Args:
+             ignore_index (int): Index to ignore in the loss calculation
+             beta (float): Temperature parameter for the KTO loss
+             compiled (bool): Whether to use compiled operations
+             use_ref_model (bool): Whether to use a reference model for the DPO loss.
+         """
+         super().__init__()
+         self.ignore_index = ignore_index
+         self.beta = beta
+         self.compiled = compiled
+         self.use_ref_model = use_ref_model
+
+     def forward(
+         self,
+         _input,
+         lin_weight,
+         target,
+         bias=None,
+         preference_labels=None,
+         ref_input=None,
+         ref_weight=None,
+         ref_bias=None,
+         kl=None,
+     ):
+         return LigerFusedLinearKTOFunction.apply(
+             _input,
+             lin_weight,
+             target,
+             preference_labels,
+             bias,
+             ref_input,
+             ref_weight,
+             ref_bias,
+             kl,
+             self.ignore_index,
+             self.beta,
+             self.compiled,
+             self.use_ref_model,
+         )
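To make the arithmetic in preference_loss_fn concrete, here is a small stand-alone sketch of the same computation with made-up per-example values; the optional kl term is omitted, i.e. the else branch above.

import torch
import torch.nn.functional as F

# Hypothetical average log-probs for three responses (values invented for illustration).
avg_logp = torch.tensor([-1.2, -2.0, -0.8])            # policy model
ref_avg_logp = torch.tensor([-1.5, -1.8, -1.0])        # reference model
preference_labels = torch.tensor([True, False, True])  # True = desirable, False = undesirable
beta = 0.1

logratios = avg_logp - ref_avg_logp
multiplier = torch.where(preference_labels, 1, -1)
losses = 1 - F.sigmoid(beta * logratios * multiplier)
# The kernel divides by the first dimension of the full, non-chunked target tensor;
# here we simply average over the three examples.
loss = losses.sum() / preference_labels.shape[0]

Desirable responses are rewarded for a positive log-ratio against the reference, undesirable ones for a negative log-ratio.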
liger_kernel/chunked_loss/orpo_loss.py
@@ -1,13 +1,10 @@
  import torch
  import torch.nn.functional as F

- from liger_kernel.chunked_loss.fused_linear_preference import (
-     LigerFusedLinearPreferenceBase,
- )
+ from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase


  class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
-
      @staticmethod
      def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
          """
@@ -32,11 +29,10 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
              beta (float): Weight for the odds ratio loss.
          """
          log_odds = (chosen_logps - rejected_logps) - (
-             torch.log1p(-torch.exp(chosen_logps))
-             - torch.log1p(-torch.exp(rejected_logps))
+             torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
          )
          ratio = F.logsigmoid(log_odds)
-         loss = beta * ratio.sum() / (full_target.shape[0] // 2)
+         loss = -beta * ratio.sum() / (full_target.shape[0] // 2)

          chosen_rewards = beta * chosen_logps
          rejected_rewards = beta * rejected_logps
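Note on the sign flip above: ORPO's odds-ratio term is a negative log-sigmoid of the log odds, so the batch-averaged loss needs the leading minus. F.logsigmoid(log_odds) is always non-positive, so the 0.5.1 expression yielded a non-positive "loss"; 0.5.3 negates it. As a quick check with hypothetical values, chosen_logps = log 0.6 and rejected_logps = log 0.3 give log_odds ≈ 1.25 and logsigmoid(log_odds) ≈ -0.25, so the corrected per-pair term is roughly +0.25 * beta, shrinking toward 0 as the odds ratio grows.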
@@ -56,6 +52,7 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
          ignore_index=-100,
          beta=0.1,
          compute_nll_loss=True,
+         nll_target=None,
          compiled=True,
      ):
          return LigerFusedLinearPreferenceBase.forward(
@@ -68,13 +65,14 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
              ignore_index=ignore_index,
              beta=beta,
              compute_nll_loss=compute_nll_loss,
+             nll_target=nll_target,
              compiled=compiled,
          )

      @staticmethod
      def backward(ctx, *grad_output):
          grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-         return *grads, None, None, None, None
+         return *grads, None, None, None, None, None


  class LigerFusedLinearORPOLoss(torch.nn.Module):
@@ -100,7 +98,7 @@ class LigerFusedLinearORPOLoss(torch.nn.Module):
          self.compute_nll_loss = compute_nll_loss
          self.compiled = compiled

-     def forward(self, lin_weight, _input, target, bias=None):
+     def forward(self, lin_weight, _input, target, bias=None, nll_target=None):
          return LigerFusedLinearORPOFunction.apply(
              _input,
              lin_weight,
@@ -109,5 +107,6 @@ class LigerFusedLinearORPOLoss(torch.nn.Module):
              self.ignore_index,
              self.beta,
              self.compute_nll_loss,
+             nll_target,
              self.compiled,
          )
liger_kernel/chunked_loss/simpo_loss.py
@@ -1,16 +1,18 @@
  import torch
  import torch.nn.functional as F

- from liger_kernel.chunked_loss.fused_linear_preference import (
-     LigerFusedLinearPreferenceBase,
- )
+ from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase


  class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
-
      @staticmethod
      def preference_loss_fn(
-         chosen_logps, rejected_logps, full_target, beta=0.1, gamma=0.5
+         chosen_logps,
+         rejected_logps,
+         full_target,
+         beta=0.1,
+         gamma=0.5,
+         label_smoothing=0.0,
      ):
          """
          Paper: https://arxiv.org/pdf/2405.14734
@@ -33,10 +35,17 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
              full_target: Non chunked full target tensor
              beta (float): beta weight
              gamma (float): gemma margin term
+             label_smoothing (float): Label smoothing factor, will reduce to Equation above when label_smoothing -> 0.
          """
          logits = beta * (chosen_logps - rejected_logps) - gamma
-         loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
-         return loss
+         loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / (
+             full_target.shape[0] // 2
+         )
+
+         chosen_rewards = beta * chosen_logps
+         rejected_rewards = beta * rejected_logps
+
+         return loss, chosen_rewards, rejected_rewards

      @staticmethod
      def forward(
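The smoothed objective introduced above is, per pair, -(1 - label_smoothing) * logsigmoid(beta * delta - gamma) - label_smoothing * logsigmoid(-(beta * delta - gamma)), where delta = chosen_logps - rejected_logps; at label_smoothing = 0 it reduces to the plain SimPO term -logsigmoid(beta * delta - gamma). The same hunk also fixes the 0.5.1 sign (the old code returned logsigmoid(logits) without the leading minus) and makes the function return chosen/rejected rewards (beta times the respective log-probs) alongside the loss, in line with the other chunked preference losses shown above.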
@@ -48,6 +57,7 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
          ignore_index=-100,
          beta=0.1,
          alpha=1.0,
+         label_smoothing=0.0,
          compute_nll_loss=False,
          compiled=True,
          gamma=0.5,
@@ -63,6 +73,7 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
              ignore_index=ignore_index,
              alpha=alpha,
              beta=beta,
+             label_smoothing=label_smoothing,
              compiled=compiled,
              gamma=gamma,
          )
@@ -70,7 +81,7 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
      @staticmethod
      def backward(ctx, *grad_output):
          grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-         return *grads, None, None, None, None, None, None
+         return *grads, None, None, None, None, None, None, None


  class LigerFusedLinearSimPOLoss(torch.nn.Module):
@@ -83,6 +94,7 @@ class LigerFusedLinearSimPOLoss(torch.nn.Module):
          ignore_index: int = -100,
          beta: float = 0.1,
          alpha: float = 1.0,
+         label_smoothing: float = 0.0,
          compute_nll_loss: bool = True,
          compiled: bool = True,
          gamma: float = 0.5,
@@ -96,6 +108,7 @@ class LigerFusedLinearSimPOLoss(torch.nn.Module):
          self.ignore_index = ignore_index
          self.beta = beta
          self.alpha = alpha
+         self.label_smoothing = label_smoothing
          self.compute_nll_loss = compute_nll_loss
          self.compiled = compiled
          self.gamma = gamma
@@ -109,6 +122,7 @@ class LigerFusedLinearSimPOLoss(torch.nn.Module):
              self.ignore_index,
              self.beta,
              self.alpha,
+             self.label_smoothing,
              self.compute_nll_loss,
              self.compiled,
              self.gamma,
liger_kernel/env_report.py
@@ -1,12 +1,13 @@
  import platform
  import sys
+
  from importlib.metadata import version


  def print_env_report():
      """

-     Prints a report of the environment. Useful for debugging and reproducibility.
+     Prints a report of the environment. Useful for debugging and reproducibility.
      Usage:
      ```
      python -m liger_kernel.env_report
@@ -27,15 +28,9 @@ def print_env_report():
          import torch

          print(f"PyTorch version: {torch.__version__}")
-         cuda_version = (
-             torch.version.cuda if torch.cuda.is_available() else "Not available"
-         )
+         cuda_version = torch.version.cuda if torch.cuda.is_available() else "Not available"
          print(f"CUDA version: {cuda_version}")
-         hip_version = (
-             torch.version.hip
-             if torch.cuda.is_available() and torch.version.hip
-             else "Not available"
-         )
+         hip_version = torch.version.hip if torch.cuda.is_available() and torch.version.hip else "Not available"
          print(f"HIP(ROCm) version: {hip_version}")

      except ImportError:
@@ -58,9 +53,7 @@ def print_env_report():
          print("Transformers: Not installed")

      try:
-         xpu_version = (
-             torch.version.xpu if torch.xpu.is_available() else "XPU Not Available"
-         )
+         xpu_version = torch.version.xpu if torch.xpu.is_available() else "XPU Not Available"
          print(f"XPU version: {xpu_version}")
      except ImportError:
          print("XPU version: Unable to query")