liger-kernel-nightly 0.4.2.dev20241121224158__tar.gz → 0.4.2.dev20241122052539__tar.gz

This diff compares the contents of two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (66)
  1. {liger_kernel_nightly-0.4.2.dev20241121224158/src/liger_kernel_nightly.egg-info → liger_kernel_nightly-0.4.2.dev20241122052539}/PKG-INFO +1 -1
  2. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/pyproject.toml +1 -1
  3. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/chunked_loss/dpo_loss.py +36 -4
  4. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/chunked_loss/fused_linear_preference.py +79 -27
  5. liger_kernel_nightly-0.4.2.dev20241122052539/src/liger_kernel/transformers/functional.py +173 -0
  6. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539/src/liger_kernel_nightly.egg-info}/PKG-INFO +1 -1
  7. liger_kernel_nightly-0.4.2.dev20241121224158/src/liger_kernel/transformers/functional.py +0 -58
  8. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/LICENSE +0 -0
  9. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/NOTICE +0 -0
  10. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/README.md +0 -0
  11. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/setup.cfg +0 -0
  12. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  13. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  14. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/chunked_loss/functional.py +0 -0
  15. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  16. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  17. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/env_report.py +0 -0
  18. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/ops/__init__.py +0 -0
  19. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/ops/cross_entropy.py +0 -0
  20. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  21. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  22. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  23. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  24. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/ops/geglu.py +0 -0
  25. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/ops/group_norm.py +0 -0
  26. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/ops/jsd.py +0 -0
  27. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/ops/kl_div.py +0 -0
  28. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/ops/layer_norm.py +0 -0
  29. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  30. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/ops/rms_norm.py +0 -0
  31. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/ops/rope.py +0 -0
  32. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/ops/swiglu.py +0 -0
  33. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/ops/utils.py +0 -0
  34. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/transformers/__init__.py +0 -0
  35. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/transformers/auto_model.py +0 -0
  36. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  37. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  38. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  39. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  40. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/transformers/geglu.py +0 -0
  41. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/transformers/group_norm.py +0 -0
  42. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/transformers/jsd.py +0 -0
  43. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/transformers/kl_div.py +0 -0
  44. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/transformers/layer_norm.py +0 -0
  45. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/transformers/model/__init__.py +0 -0
  46. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/transformers/model/gemma.py +0 -0
  47. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  48. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/transformers/model/llama.py +0 -0
  49. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/transformers/model/mistral.py +0 -0
  50. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  51. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/transformers/model/mllama.py +0 -0
  52. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/transformers/model/phi3.py +0 -0
  53. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  54. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  55. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  56. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  57. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/transformers/rms_norm.py +0 -0
  58. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/transformers/rope.py +0 -0
  59. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/transformers/swiglu.py +0 -0
  60. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  61. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/triton/__init__.py +0 -0
  62. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel/triton/monkey_patch.py +0 -0
  63. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
  64. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  65. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  66. {liger_kernel_nightly-0.4.2.dev20241121224158 → liger_kernel_nightly-0.4.2.dev20241122052539}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel_nightly
-Version: 0.4.2.dev20241121224158
+Version: 0.4.2.dev20241122052539
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
 Copyright 2024 LinkedIn Corporation
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "liger_kernel_nightly"
-version = "0.4.2.dev20241121224158"
+version = "0.4.2.dev20241122052539"
description = "Efficient Triton kernels for LLM Training"
 urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
 readme = { file = "README.md", content-type = "text/markdown" }
@@ -9,15 +9,31 @@ from liger_kernel.chunked_loss.fused_linear_preference import (
 class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
 
     @staticmethod
-    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
+    def preference_loss_fn(
+        chosen_logps,
+        rejected_logps,
+        ref_chosen_logps=None,
+        ref_rejected_logps=None,
+        beta=0.1,
+    ):
         """
         Compute DPO loss (Direct Preference Optimization).
         Args:
            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
+           ref_chosen_logps (torch.Tensor, optional): Reference log probabilities of chosen tokens. Shape: (batch_size,).
+           ref_rejected_logps (torch.Tensor, optional): Reference log probabilities of rejected tokens. Shape: (batch_size,).
            beta (float): Weight for the direct preference loss.
         """
-        logits_diff = beta * (chosen_logps - rejected_logps)
+        if ref_chosen_logps is None:
+            ref_chosen_logps = torch.tensor(0.0, device=chosen_logps.device)
+        if ref_rejected_logps is None:
+            ref_rejected_logps = torch.tensor(0.0, device=rejected_logps.device)
+
+        chosen_logratios = chosen_logps - ref_chosen_logps
+        rejected_logratios = rejected_logps - ref_rejected_logps
+
+        logits_diff = beta * (chosen_logratios - rejected_logratios)
         losses = -F.logsigmoid(logits_diff)
         return losses.sum()
 
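The rewritten loss implements the standard DPO objective against a reference model, -logsigmoid(beta * ((chosen - ref_chosen) - (rejected - ref_rejected))), falling back to the reference-free form when the reference log-probs are omitted. A minimal standalone sketch of the same computation on toy values (numbers and shapes are illustrative only, not from the package):

import torch
import torch.nn.functional as F

beta = 0.1
# Policy and reference average log-probs per sequence (toy values)
chosen_logps = torch.tensor([-1.0, -0.8])
rejected_logps = torch.tensor([-1.5, -1.2])
ref_chosen_logps = torch.tensor([-1.1, -0.9])
ref_rejected_logps = torch.tensor([-1.4, -1.3])

# beta * (chosen log-ratio minus rejected log-ratio), as in the new preference_loss_fn
logits_diff = beta * (
    (chosen_logps - ref_chosen_logps) - (rejected_logps - ref_rejected_logps)
)
loss = -F.logsigmoid(logits_diff).sum()  # summed over the batch, matching the diff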
@@ -28,10 +44,13 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         weight,
         target,
         bias=None,
+        ref_weight=None,
+        ref_bias=None,
         ignore_index=-100,
         beta=0.1,
         compute_nll_loss=True,
         compiled=True,
+        use_ref_model=True,
     ):
         """
         Fused linear layer with DPO (Direct Preference Optimization) loss.
@@ -48,6 +67,9 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
             beta=beta,
             compute_nll_loss=compute_nll_loss,
             compiled=compiled,
+            use_ref_model=use_ref_model,
+            ref_weight=ref_weight,
+            ref_bias=ref_bias,
         )
 
     @staticmethod
@@ -55,7 +77,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         # Get gradients for _input, weight, bias, and target from the base class
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
         # Return these gradients, followed by None for the remaining inputs
-        return *grads, None, None, None, None
+        return *grads, None, None, None, None, None, None, None
 
 
 class LigerFusedLinearDPOLoss(torch.nn.Module):
@@ -69,26 +91,36 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
         beta: float = 0.1,
         compute_nll_loss: bool = True,
         compiled: bool = True,
+        use_ref_model: bool = False,
     ):
         """
         Args:
            ignore_index (int): Index to ignore in the loss.
            beta (float): Weight for the odds ratio loss.
+           compute_nll_loss (bool): Whether to compute the NLL loss.
+           compiled (bool): Whether to use the torch compiled kernel.
+           use_ref_model (bool): Whether to use a reference model for the DPO loss.
         """
         super().__init__()
         self.ignore_index = ignore_index
         self.beta = beta
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled
+        self.use_ref_model = use_ref_model
 
-    def forward(self, lin_weight, _input, target, bias=None):
+    def forward(
+        self, lin_weight, _input, target, bias=None, ref_weight=None, ref_bias=None
+    ):
         return LigerFusedLinearDPOFunction.apply(
             _input,
             lin_weight,
             target,
             bias,
+            ref_weight,
+            ref_bias,
             self.ignore_index,
             self.beta,
             self.compute_nll_loss,
             self.compiled,
+            self.use_ref_model,
         )
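Taken together, the dpo_loss.py changes let callers pass a frozen reference head through the module API. A hypothetical usage sketch (shapes, sizes, and the requires_grad setup below are illustrative assumptions, not from the package):

import torch
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss

B, T, H, V = 2, 16, 64, 128  # pairs, seq len, hidden dim, vocab (toy sizes)
loss_fn = LigerFusedLinearDPOLoss(beta=0.1, use_ref_model=True)

lin_weight = torch.randn(V, H, requires_grad=True)  # policy lm_head weight
ref_weight = torch.randn(V, H)                      # frozen reference lm_head weight
_input = torch.randn(2 * B, T, H)                   # chosen rows first, then rejected
target = torch.randint(0, V, (2 * B, T))

loss = loss_fn(lin_weight, _input, target, ref_weight=ref_weight)
loss.backward()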
@@ -18,6 +18,42 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         """
         raise NotImplementedError("Preference loss function must be implemented.")
 
+    @staticmethod
+    def chunk_forward(
+        input_chunk,
+        weight,
+        target_chunk,
+        bias=None,
+        ignore_index=-100,
+        compute_nll_loss=True,
+    ):
+        len_chosen_chunk = target_chunk.shape[0] // 2
+        logits_chunk = input_chunk @ weight.t()
+        if bias is not None:
+            logits_chunk = logits_chunk + bias
+        log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
+
+        chosen_nll_loss = 0.0
+        if compute_nll_loss:
+            chosen_nll_loss = F.nll_loss(
+                log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
+                target_chunk[:len_chosen_chunk].view(-1),
+                reduction="sum",
+                ignore_index=ignore_index,
+            )
+
+        loss_mask = target_chunk != ignore_index
+        label_chunk = torch.where(loss_mask, target_chunk, 0)
+
+        per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
+            -1
+        )
+        average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+
+        chosen_logps = average_log_prob[:len_chosen_chunk]
+        rejected_logps = average_log_prob[len_chosen_chunk:]
+        return chosen_logps, rejected_logps, chosen_nll_loss
+
     @staticmethod
     def forward(
         ctx,
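The new chunk_forward helper centralizes the per-chunk projection and log-prob bookkeeping so the policy pass and the reference pass can share it. Its contract, checked on toy tensors (sizes are illustrative): the first half of each chunk is treated as chosen, the second half as rejected, and it returns per-sequence average log-probs for each half.

import torch
from liger_kernel.chunked_loss.fused_linear_preference import (
    LigerFusedLinearPreferenceBase,
)

B, T, H, V = 2, 5, 8, 16
input_chunk = torch.randn(2 * B, T, H)          # chosen rows first, rejected rows last
weight = torch.randn(V, H)                      # lm_head weight: (vocab, hidden)
target_chunk = torch.randint(0, V, (2 * B, T))

chosen_logps, rejected_logps, nll = LigerFusedLinearPreferenceBase.chunk_forward(
    input_chunk, weight, target_chunk
)
assert chosen_logps.shape == rejected_logps.shape == (B,)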
@@ -32,6 +68,9 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         beta=0.1,
         compute_nll_loss=True,
         compiled=True,
+        use_ref_model=False,
+        ref_weight=None,
+        ref_bias=None,
         **loss_kwargs,
     ):
         """
@@ -49,7 +88,11 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
            ignore_index (int): Index to ignore for loss computation.
            alpha (float): Weight for the NLL loss.
            beta (float): Weight for the odds ratio loss.
+           compute_nll_loss (bool): Whether to compute NLL loss.
            compiled (bool): Whether to use torch compile for chunk accumulation.
+           use_ref_model (bool): Whether to use a reference model for the alignment loss.
+           ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
+           ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
            loss_kwargs (dict): Other possible arguments that a loss function might need
         """
         # TODO: Tune CHUNK_SIZE to fully utilize the GPU
@@ -61,7 +104,6 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         grad_bias = torch.zeros_like(bias) if bias is not None else None
         loss_acc = torch.zeros((), device=_input.device)
 
-        chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
         loss_func_to_call = partial(
             LigerFusedLinearPreferenceBase._compute_loss,
             preference_loss_fn=loss_fn,
@@ -70,6 +112,9 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             beta=beta,
             compute_nll_loss=compute_nll_loss,
             full_target=target,
+            use_ref_model=use_ref_model,
+            ref_weight=ref_weight,
+            ref_bias=ref_bias,
             **loss_kwargs,
         )
 
@@ -101,6 +146,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
            accumulate_chunk = torch.compile(accumulate_chunk)
 
         len_chosen = target.shape[0] // 2
+        chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
         _chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0)
         _chosen_target_chunks = torch.chunk(target[:len_chosen], chunks=chunks, dim=0)
         _rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
@@ -159,6 +205,9 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         alpha=1.0,
         beta=0.1,
         compute_nll_loss=True,
+        use_ref_model=False,
+        ref_weight=None,
+        ref_bias=None,
         **loss_kwargs,
     ):
         """
  """
@@ -173,38 +222,41 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
173
222
  ignore_index (int): Index to ignore for loss computation.
174
223
  alpha (float): Weight for the NLL loss.
175
224
  beta (float): Weight for the odds ratio loss.
225
+ compute_nll_loss (bool): Whether to compute NLL loss.
226
+ use_ref_model (bool): Whether to use a reference model for the alignment loss.
227
+ ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
228
+ ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
176
229
  loss_kwargs (dict): Additional arguments for the loss function.
177
230
  """
178
- len_chosen_chunk = target_chunk.shape[0] // 2
179
-
180
- logits_chunk = input_chunk @ weight.t() # chunk_size x V
181
- if bias is not None:
182
- logits_chunk = logits_chunk + bias
183
- log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
184
-
185
- chosen_nll_loss = 0.0
186
- if compute_nll_loss:
187
- chosen_nll_loss = F.nll_loss(
188
- log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
189
- target_chunk[:len_chosen_chunk].view(-1),
190
- reduction="sum",
231
+ chosen_logps, rejected_logps, chosen_nll_loss = (
232
+ LigerFusedLinearPreferenceBase.chunk_forward(
233
+ input_chunk,
234
+ weight,
235
+ target_chunk,
236
+ bias=bias,
191
237
  ignore_index=ignore_index,
238
+ compute_nll_loss=compute_nll_loss,
192
239
  )
193
- chosen_nll_loss = (
194
- chosen_nll_loss
195
- / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
196
- )
197
-
198
- loss_mask = target_chunk != ignore_index
199
- label_chunk = torch.where(loss_mask, target_chunk, 0)
200
-
201
- per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
202
- -1
203
240
  )
204
- average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
241
+ chosen_nll_loss = (
242
+ chosen_nll_loss
243
+ / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
244
+ )
205
245
 
206
- chosen_logps = average_log_prob[:len_chosen_chunk]
207
- rejected_logps = average_log_prob[len_chosen_chunk:]
246
+ if use_ref_model:
247
+ with torch.no_grad():
248
+ ref_chosen_logps, ref_rejected_logps, _ = (
249
+ LigerFusedLinearPreferenceBase.chunk_forward(
250
+ input_chunk,
251
+ ref_weight,
252
+ target_chunk,
253
+ ref_bias,
254
+ ignore_index=ignore_index,
255
+ compute_nll_loss=False,
256
+ )
257
+ )
258
+ loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
259
+ loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
208
260
 
209
261
  alignment_loss = preference_loss_fn(
210
262
  chosen_logps, rejected_logps, beta=beta, **loss_kwargs
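With the chunk count now computed next to the torch.chunk calls, the accumulation path reads straight through: the concatenated batch is split into chosen and rejected halves, each half is chunked in lockstep, and each accumulate step consumes one matched chosen/rejected pair. A small sketch of that slicing arithmetic (CHUNK_SIZE and shapes below are toy values, not the package's tuned defaults):

import torch

CHUNK_SIZE = 1
B = 4                                  # preference pairs
_input = torch.randn(2 * B, 3, 8)      # first B rows chosen, last B rejected
target = torch.randint(0, 16, (2 * B, 3))

len_chosen = target.shape[0] // 2
chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
chosen_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0)
rejected_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
assert len(chosen_chunks) == len(rejected_chunks) == chunks  # matched pairs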
@@ -0,0 +1,173 @@
+from typing import Optional
+
+from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
+from liger_kernel.ops.fused_linear_cross_entropy import (
+    LigerFusedLinearCrossEntropyFunction,
+)
+from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
+from liger_kernel.ops.geglu import LigerGELUMulFunction
+from liger_kernel.ops.group_norm import LigerGroupNormFunction
+from liger_kernel.ops.jsd import LigerJSDFunction
+from liger_kernel.ops.kl_div import LigerKLDivLossFunction
+from liger_kernel.ops.layer_norm import LigerLayerNormFunction
+from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
+from liger_kernel.ops.rms_norm import LigerRMSNormFunction
+from liger_kernel.ops.rope import LigerRopeFunction
+from liger_kernel.ops.swiglu import LigerSiLUMulFunction
+
+
+# conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
+# `weight` and `size_average` are placeholders and not implemented yet
+def liger_cross_entropy(
+    input,
+    target,
+    weight=None,
+    size_average=None,
+    ignore_index: int = -100,
+    reduce=None,
+    reduction: str = "mean",
+    label_smoothing: float = 0.0,
+    lse_square_scale: float = 0.0,
+    softcap: Optional[float] = None,
+    return_z_loss: bool = False,
+):
+    loss, z_loss = LigerCrossEntropyFunction.apply(
+        input,
+        target,
+        ignore_index,
+        lse_square_scale,
+        label_smoothing,
+        reduction,
+        softcap,
+        return_z_loss,
+    )
+    if not return_z_loss:
+        return loss
+    return loss, z_loss
+
+
+def liger_fused_linear_cross_entropy(
+    input,
+    weight,
+    target,
+    bias=None,
+    ignore_index: int = -100,
+    lse_square_scale: float = 0.0,
+    label_smoothing: float = 0.0,
+    reduction: str = "mean",
+    softcap: Optional[float] = None,
+):
+    return LigerFusedLinearCrossEntropyFunction.apply(
+        input,
+        weight,
+        target,
+        bias,
+        ignore_index,
+        lse_square_scale,
+        label_smoothing,
+        reduction,
+        softcap,
+    )
+
+
+def liger_fused_linear_jsd(
+    student_input,
+    student_weight,
+    teacher_input,
+    teacher_weight,
+    shift_labels=None,
+    jsd_beta: float = 0.5,
+    ignore_index: int = -100,
+    temperature: float = 1.0,
+):
+    return LigerFusedLinearJSDFunction.apply(
+        student_input,
+        student_weight,
+        teacher_input,
+        teacher_weight,
+        shift_labels,
+        jsd_beta,
+        ignore_index,
+        temperature,
+    )
+
+
+def liger_geglu(a, b):
+    return LigerGELUMulFunction.apply(a, b)
+
+
+def liger_group_norm(
+    X,
+    affine_scaling_weight,
+    affine_shifting_bias,
+    num_channels,
+    num_groups,
+    eps,
+):
+    return LigerGroupNormFunction.apply(
+        X,
+        affine_scaling_weight,
+        affine_shifting_bias,
+        num_channels,
+        num_groups,
+        eps,
+    )
+
+
+def liger_jsd(
+    input,
+    target,
+    shift_labels=None,
+    beta: float = 0.5,
+    ignore_index: int = -100,
+):
+    return LigerJSDFunction.apply(
+        input,
+        target,
+        shift_labels,
+        beta,
+        ignore_index,
+    )
+
+
+# conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.kl_div.html#torch.nn.functional.kl_div
+# `size_average` and `mean` are being deprecated in torch API and are placeholders here
+def liger_kl_div(
+    input,
+    target,
+    size_average: bool = True,
+    reduce: bool = True,
+    reduction: str = "mean",
+    log_target: bool = False,
+    eps: float = 1e-10,
+):
+    # Note: the default reduction in torch is `mean`, but being `batchmean` in Liger
+    return LigerKLDivLossFunction.apply(
+        input,
+        target,
+        reduction,
+        log_target,
+        eps,
+    )
+
+
+def liger_layer_norm(X, W, B, eps):
+    return LigerLayerNormFunction.apply(X, W, B, eps)
+
+
+def liger_qwen2vl_mrope(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
+    return LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section, unsqueeze_dim)
+
+
+def liger_rms_norm(
+    X, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True
+):
+    return LigerRMSNormFunction.apply(X, W, eps, offset, casting_mode, in_place)
+
+
+def liger_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim)
+
+
+def liger_swiglu(a, b):
+    return LigerSiLUMulFunction.apply(a, b)
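The functional module is rewritten from bare `Function.apply` aliases into explicit wrappers, so each op now has a real Python signature with keyword arguments (and, for liger_cross_entropy and liger_kl_div, signatures that mirror their torch.nn.functional counterparts). A hypothetical call site (toy shapes; the underlying Triton kernels expect CUDA tensors):

import torch
from liger_kernel.transformers.functional import liger_cross_entropy

logits = torch.randn(8, 32, device="cuda", requires_grad=True)
labels = torch.randint(0, 32, (8,), device="cuda")
# Keyword arguments now work, unlike with the old positional-only `.apply` alias
loss = liger_cross_entropy(logits, labels, ignore_index=-100, reduction="mean")
loss.backward()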
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel_nightly
-Version: 0.4.2.dev20241121224158
+Version: 0.4.2.dev20241122052539
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
 Copyright 2024 LinkedIn Corporation
@@ -1,58 +0,0 @@
-from typing import Optional
-
-from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
-from liger_kernel.ops.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyFunction,
-)
-from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
-from liger_kernel.ops.geglu import LigerGELUMulFunction
-from liger_kernel.ops.group_norm import LigerGroupNormFunction
-from liger_kernel.ops.jsd import LigerJSDFunction
-from liger_kernel.ops.kl_div import LigerKLDivLossFunction
-from liger_kernel.ops.layer_norm import LigerLayerNormFunction
-from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
-from liger_kernel.ops.rms_norm import LigerRMSNormFunction
-from liger_kernel.ops.rope import LigerRopeFunction
-from liger_kernel.ops.swiglu import LigerSiLUMulFunction
-
-liger_swiglu = LigerSiLUMulFunction.apply
-liger_fused_linear_cross_entropy = LigerFusedLinearCrossEntropyFunction.apply
-liger_geglu = LigerGELUMulFunction.apply
-liger_rms_norm = LigerRMSNormFunction.apply
-liger_rope = LigerRopeFunction.apply
-liger_qwen2vl_mrope = LigerQwen2VLMRopeFunction.apply
-liger_layer_norm = LigerLayerNormFunction.apply
-liger_kl_div = LigerKLDivLossFunction.apply
-liger_jsd = LigerJSDFunction.apply
-liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
-liger_group_norm = LigerGroupNormFunction.apply
-
-
-# conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
-# `weight` and `size_average` are placeholders and not implemented yet
-def liger_cross_entropy(
-    input,
-    target,
-    weight=None,
-    size_average=None,
-    ignore_index: int = -100,
-    reduce=None,
-    reduction: str = "mean",
-    label_smoothing: float = 0.0,
-    lse_square_scale: float = 0.0,
-    softcap: Optional[float] = None,
-    return_z_loss: bool = False,
-):
-    loss, z_loss = LigerCrossEntropyFunction.apply(
-        input,
-        target,
-        ignore_index,
-        lse_square_scale,
-        label_smoothing,
-        reduction,
-        softcap,
-        return_z_loss,
-    )
-    if not return_z_loss:
-        return loss
-    return loss, z_loss