liger-kernel-nightly 0.4.2.dev20241119223206__tar.gz → 0.4.2.dev20241121054604__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. {liger_kernel_nightly-0.4.2.dev20241119223206/src/liger_kernel_nightly.egg-info → liger_kernel_nightly-0.4.2.dev20241121054604}/PKG-INFO +1 -1
  2. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/pyproject.toml +1 -1
  3. liger_kernel_nightly-0.4.2.dev20241121054604/src/liger_kernel/chunked_loss/__init__.py +4 -0
  4. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/chunked_loss/cpo_loss.py +41 -1
  5. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/chunked_loss/dpo_loss.py +38 -1
  6. liger_kernel_nightly-0.4.2.dev20241121054604/src/liger_kernel/chunked_loss/functional.py +9 -0
  7. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/chunked_loss/orpo_loss.py +38 -2
  8. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/chunked_loss/simpo_loss.py +43 -0
  9. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604/src/liger_kernel_nightly.egg-info}/PKG-INFO +1 -1
  10. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel_nightly.egg-info/SOURCES.txt +1 -0
  11. liger_kernel_nightly-0.4.2.dev20241119223206/src/liger_kernel/transformers/model/__init__.py +0 -0
  12. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/LICENSE +0 -0
  13. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/NOTICE +0 -0
  14. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/README.md +0 -0
  15. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/setup.cfg +0 -0
  16. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/chunked_loss/fused_linear_preference.py +1 -1
  17. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/env_report.py +0 -0
  18. {liger_kernel_nightly-0.4.2.dev20241119223206/src/liger_kernel/chunked_loss → liger_kernel_nightly-0.4.2.dev20241121054604/src/liger_kernel/ops}/__init__.py +0 -0
  19. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/ops/cross_entropy.py +0 -0
  20. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  21. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  22. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  23. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  24. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/ops/geglu.py +0 -0
  25. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/ops/group_norm.py +0 -0
  26. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/ops/jsd.py +0 -0
  27. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/ops/kl_div.py +0 -0
  28. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/ops/layer_norm.py +0 -0
  29. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  30. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/ops/rms_norm.py +0 -0
  31. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/ops/rope.py +0 -0
  32. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/ops/swiglu.py +0 -0
  33. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/ops/utils.py +0 -0
  34. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/transformers/__init__.py +0 -0
  35. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/transformers/auto_model.py +0 -0
  36. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  37. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  38. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/transformers/functional.py +0 -0
  39. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  40. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  41. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/transformers/geglu.py +0 -0
  42. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/transformers/group_norm.py +0 -0
  43. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/transformers/jsd.py +0 -0
  44. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/transformers/kl_div.py +0 -0
  45. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/transformers/layer_norm.py +0 -0
  46. {liger_kernel_nightly-0.4.2.dev20241119223206/src/liger_kernel/ops → liger_kernel_nightly-0.4.2.dev20241121054604/src/liger_kernel/transformers/model}/__init__.py +0 -0
  47. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/transformers/model/gemma.py +0 -0
  48. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  49. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/transformers/model/llama.py +0 -0
  50. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/transformers/model/mistral.py +0 -0
  51. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  52. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/transformers/model/mllama.py +0 -0
  53. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/transformers/model/phi3.py +0 -0
  54. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  55. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  56. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  57. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  58. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/transformers/rms_norm.py +0 -0
  59. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/transformers/rope.py +0 -0
  60. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/transformers/swiglu.py +0 -0
  61. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  62. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/triton/__init__.py +0 -0
  63. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel/triton/monkey_patch.py +0 -0
  64. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  65. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  66. {liger_kernel_nightly-0.4.2.dev20241119223206 → liger_kernel_nightly-0.4.2.dev20241121054604}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.4.2.dev20241119223206
3
+ Version: 0.4.2.dev20241121054604
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.4.2.dev20241119223206"
7
+ version = "0.4.2.dev20241121054604"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -0,0 +1,4 @@
1
+ from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401
2
+ from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401
3
+ from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401
4
+ from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss # noqa: F401
@@ -1,3 +1,4 @@
1
+ import torch
1
2
  import torch.nn.functional as F
2
3
 
3
4
  from liger_kernel.chunked_loss.fused_linear_preference import (
@@ -46,10 +47,10 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
46
47
  target,
47
48
  bias,
48
49
  loss_fn=LigerFusedLinearCPOFunction.preference_loss_fn,
49
- compute_nll_loss=compute_nll_loss,
50
50
  ignore_index=ignore_index,
51
51
  alpha=alpha,
52
52
  beta=beta,
53
+ compute_nll_loss=compute_nll_loss,
53
54
  compiled=compiled,
54
55
  )
55
56
 
@@ -59,3 +60,42 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
59
60
  grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
60
61
  # Return these gradients, followed by None for the remaining inputs
61
62
  return *grads, None, None, None, None, None
63
+
64
+
65
class LigerFusedLinearCPOLoss(torch.nn.Module):
    """
    Fused linear layer with CPO loss.

    Thin ``torch.nn.Module`` wrapper: it stores the loss hyperparameters and,
    on every ``forward`` call, passes them together with the given tensors to
    ``LigerFusedLinearCPOFunction.apply``.
    """

    def __init__(
        self,
        ignore_index: int = -100,
        beta: float = 0.1,
        alpha: float = 1.0,
        compute_nll_loss: bool = True,
        compiled: bool = True,
    ):
        """
        Args:
            ignore_index (int): Target value to ignore in the loss.
            beta (float): ``beta`` hyperparameter of the CPO preference loss,
                forwarded unchanged to ``LigerFusedLinearCPOFunction``.
            alpha (float): ``alpha`` hyperparameter forwarded unchanged to
                ``LigerFusedLinearCPOFunction`` (presumably the weight of the
                NLL term -- confirm against ``LigerFusedLinearPreferenceBase``).
            compute_nll_loss (bool): Forwarded unchanged as the function's
                ``compute_nll_loss`` flag (presumably whether an NLL term is
                also computed -- confirm).
            compiled (bool): Forwarded unchanged as the function's ``compiled``
                flag (presumably toggles ``torch.compile`` -- confirm).
        """
        super().__init__()
        self.ignore_index = ignore_index
        self.beta = beta
        self.alpha = alpha
        self.compute_nll_loss = compute_nll_loss
        self.compiled = compiled

    def forward(self, lin_weight, _input, target, bias=None):
        """
        Compute the CPO loss through ``LigerFusedLinearCPOFunction``.

        Note the argument order: this method takes ``lin_weight`` first, but
        ``apply`` receives ``_input`` first and ``lin_weight`` second.

        Args:
            lin_weight: Weight of the fused linear layer.
            _input: Input to the fused linear layer.
            target: Targets; entries equal to ``self.ignore_index`` are ignored.
            bias: Optional bias of the fused linear layer.

        Returns:
            Whatever ``LigerFusedLinearCPOFunction.apply`` returns
            (presumably the loss -- confirm).
        """
        # autograd.Function.apply only takes positional arguments, so the
        # order here must match the function's forward signature exactly.
        return LigerFusedLinearCPOFunction.apply(
            _input,
            lin_weight,
            target,
            bias,
            self.ignore_index,
            self.beta,
            self.alpha,
            self.compute_nll_loss,
            self.compiled,
        )
@@ -1,3 +1,4 @@
1
+ import torch
1
2
  import torch.nn.functional as F
2
3
 
3
4
  from liger_kernel.chunked_loss.fused_linear_preference import (
@@ -43,9 +44,9 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
43
44
  target=target,
44
45
  bias=bias,
45
46
  loss_fn=LigerFusedLinearDPOFunction.preference_loss_fn,
46
- compute_nll_loss=compute_nll_loss,
47
47
  ignore_index=ignore_index,
48
48
  beta=beta,
49
+ compute_nll_loss=compute_nll_loss,
49
50
  compiled=compiled,
50
51
  )
51
52
 
@@ -55,3 +56,39 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
55
56
  grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
56
57
  # Return these gradients, followed by None for the remaining inputs
57
58
  return *grads, None, None, None, None
59
+
60
+
61
class LigerFusedLinearDPOLoss(torch.nn.Module):
    """
    Fused linear layer with DPO loss.

    Thin ``torch.nn.Module`` wrapper: it stores the loss hyperparameters and,
    on every ``forward`` call, passes them together with the given tensors to
    ``LigerFusedLinearDPOFunction.apply``.
    """

    def __init__(
        self,
        ignore_index: int = -100,
        beta: float = 0.1,
        compute_nll_loss: bool = True,
        compiled: bool = True,
    ):
        """
        Args:
            ignore_index (int): Target value to ignore in the loss.
            beta (float): ``beta`` hyperparameter of the DPO preference loss,
                forwarded unchanged to ``LigerFusedLinearDPOFunction``.
            compute_nll_loss (bool): Forwarded unchanged as the function's
                ``compute_nll_loss`` flag (presumably whether an NLL term is
                also computed -- confirm).
            compiled (bool): Forwarded unchanged as the function's ``compiled``
                flag (presumably toggles ``torch.compile`` -- confirm).
        """
        super().__init__()
        self.ignore_index = ignore_index
        self.beta = beta
        self.compute_nll_loss = compute_nll_loss
        self.compiled = compiled

    def forward(self, lin_weight, _input, target, bias=None):
        """
        Compute the DPO loss through ``LigerFusedLinearDPOFunction``.

        Note the argument order: this method takes ``lin_weight`` first, but
        ``apply`` receives ``_input`` first and ``lin_weight`` second.

        Args:
            lin_weight: Weight of the fused linear layer.
            _input: Input to the fused linear layer.
            target: Targets; entries equal to ``self.ignore_index`` are ignored.
            bias: Optional bias of the fused linear layer.

        Returns:
            Whatever ``LigerFusedLinearDPOFunction.apply`` returns
            (presumably the loss -- confirm).
        """
        # autograd.Function.apply only takes positional arguments, so the
        # order here must match the function's forward signature exactly.
        return LigerFusedLinearDPOFunction.apply(
            _input,
            lin_weight,
            target,
            bias,
            self.ignore_index,
            self.beta,
            self.compute_nll_loss,
            self.compiled,
        )
@@ -0,0 +1,9 @@
1
+ from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
2
+ from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
3
+ from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
4
+ from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction
5
+
6
# Functional entry points: each alias is the ``apply`` of the matching
# autograd Function, so it takes the same positional arguments.
liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply
liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply
liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
@@ -34,7 +34,7 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
34
34
  ignore_index=-100,
35
35
  beta=0.1,
36
36
  compute_nll_loss=True,
37
- compiled=False,
37
+ compiled=True,
38
38
  ):
39
39
  """
40
40
  Fused linear layer with ORPO (Odds-Ratio Preference Optimization) loss.
@@ -49,9 +49,9 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
49
49
  target=target,
50
50
  bias=bias,
51
51
  loss_fn=LigerFusedLinearORPOFunction.preference_loss_fn,
52
- compute_nll_loss=compute_nll_loss,
53
52
  ignore_index=ignore_index,
54
53
  beta=beta,
54
+ compute_nll_loss=compute_nll_loss,
55
55
  compiled=compiled,
56
56
  )
57
57
 
@@ -61,3 +61,39 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
61
61
  grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
62
62
  # Return these gradients, followed by None for the remaining inputs
63
63
  return *grads, None, None, None, None
64
+
65
+
66
class LigerFusedLinearORPOLoss(torch.nn.Module):
    """
    Fused linear layer with ORPO (Odds-Ratio Preference Optimization) loss.

    Thin ``torch.nn.Module`` wrapper: it stores the loss hyperparameters and,
    on every ``forward`` call, passes them together with the given tensors to
    ``LigerFusedLinearORPOFunction.apply``.
    """

    def __init__(
        self,
        ignore_index: int = -100,
        beta: float = 0.1,
        compute_nll_loss: bool = True,
        compiled: bool = True,
    ):
        """
        Args:
            ignore_index (int): Target value to ignore in the loss.
            beta (float): Weight for the odds ratio loss.
            compute_nll_loss (bool): Forwarded unchanged as the function's
                ``compute_nll_loss`` flag (presumably whether an NLL term is
                also computed -- confirm).
            compiled (bool): Forwarded unchanged as the function's ``compiled``
                flag (presumably toggles ``torch.compile`` -- confirm).
        """
        super().__init__()
        self.ignore_index = ignore_index
        self.beta = beta
        self.compute_nll_loss = compute_nll_loss
        self.compiled = compiled

    def forward(self, lin_weight, _input, target, bias=None):
        """
        Compute the ORPO loss through ``LigerFusedLinearORPOFunction``.

        Note the argument order: this method takes ``lin_weight`` first, but
        ``apply`` receives ``_input`` first and ``lin_weight`` second.

        Args:
            lin_weight: Weight of the fused linear layer.
            _input: Input to the fused linear layer.
            target: Targets; entries equal to ``self.ignore_index`` are ignored.
            bias: Optional bias of the fused linear layer.

        Returns:
            Whatever ``LigerFusedLinearORPOFunction.apply`` returns
            (presumably the loss -- confirm).
        """
        # autograd.Function.apply only takes positional arguments, so the
        # order here must match the function's forward signature exactly.
        return LigerFusedLinearORPOFunction.apply(
            _input,
            lin_weight,
            target,
            bias,
            self.ignore_index,
            self.beta,
            self.compute_nll_loss,
            self.compiled,
        )
@@ -1,3 +1,4 @@
1
+ import torch
1
2
  import torch.nn.functional as F
2
3
 
3
4
  from liger_kernel.chunked_loss.fused_linear_preference import (
@@ -62,3 +63,45 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
62
63
  grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
63
64
  # Return these gradients, followed by None for the remaining inputs
64
65
  return *grads, None, None, None, None, None, None
66
+
67
+
68
class LigerFusedLinearSimPOLoss(torch.nn.Module):
    """
    Fused linear layer with SimPO loss.

    Thin ``torch.nn.Module`` wrapper: it stores the loss hyperparameters and,
    on every ``forward`` call, passes them together with the given tensors to
    ``LigerFusedLinearSimPOFunction.apply``.
    """

    def __init__(
        self,
        ignore_index: int = -100,
        beta: float = 0.1,
        alpha: float = 1.0,
        compute_nll_loss: bool = True,
        compiled: bool = True,
        gamma: float = 0.5,
    ):
        """
        Args:
            ignore_index (int): Target value to ignore in the loss.
            beta (float): ``beta`` hyperparameter of the SimPO preference loss,
                forwarded unchanged to ``LigerFusedLinearSimPOFunction``.
            alpha (float): ``alpha`` hyperparameter forwarded unchanged to
                ``LigerFusedLinearSimPOFunction`` (presumably the weight of the
                NLL term -- confirm against ``LigerFusedLinearPreferenceBase``).
            compute_nll_loss (bool): Forwarded unchanged as the function's
                ``compute_nll_loss`` flag (presumably whether an NLL term is
                also computed -- confirm).
            compiled (bool): Forwarded unchanged as the function's ``compiled``
                flag (presumably toggles ``torch.compile`` -- confirm).
            gamma (float): ``gamma`` hyperparameter forwarded unchanged to
                ``LigerFusedLinearSimPOFunction`` (presumably the SimPO target
                reward margin -- confirm). Note it is the last positional
                parameter, after ``compiled``.
        """
        super().__init__()
        self.ignore_index = ignore_index
        self.beta = beta
        self.alpha = alpha
        self.compute_nll_loss = compute_nll_loss
        self.compiled = compiled
        self.gamma = gamma

    def forward(self, lin_weight, _input, target, bias=None):
        """
        Compute the SimPO loss through ``LigerFusedLinearSimPOFunction``.

        Note the argument order: this method takes ``lin_weight`` first, but
        ``apply`` receives ``_input`` first and ``lin_weight`` second.

        Args:
            lin_weight: Weight of the fused linear layer.
            _input: Input to the fused linear layer.
            target: Targets; entries equal to ``self.ignore_index`` are ignored.
            bias: Optional bias of the fused linear layer.

        Returns:
            Whatever ``LigerFusedLinearSimPOFunction.apply`` returns
            (presumably the loss -- confirm).
        """
        # autograd.Function.apply only takes positional arguments, so the
        # order here must match the function's forward signature exactly.
        return LigerFusedLinearSimPOFunction.apply(
            _input,
            lin_weight,
            target,
            bias,
            self.ignore_index,
            self.beta,
            self.alpha,
            self.compute_nll_loss,
            self.compiled,
            self.gamma,
        )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.4.2.dev20241119223206
3
+ Version: 0.4.2.dev20241121054604
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -6,6 +6,7 @@ src/liger_kernel/env_report.py
6
6
  src/liger_kernel/chunked_loss/__init__.py
7
7
  src/liger_kernel/chunked_loss/cpo_loss.py
8
8
  src/liger_kernel/chunked_loss/dpo_loss.py
9
+ src/liger_kernel/chunked_loss/functional.py
9
10
  src/liger_kernel/chunked_loss/fused_linear_preference.py
10
11
  src/liger_kernel/chunked_loss/orpo_loss.py
11
12
  src/liger_kernel/chunked_loss/simpo_loss.py
@@ -27,10 +27,10 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
27
27
  bias=None,
28
28
  loss_fn=None,
29
29
  chunk_size=1,
30
- compute_nll_loss=True,
31
30
  ignore_index=-100,
32
31
  alpha=1.0,
33
32
  beta=0.1,
33
+ compute_nll_loss=True,
34
34
  compiled=True,
35
35
  **loss_kwargs,
36
36
  ):