liger-kernel-nightly 0.5.2.dev20241212030605__py3-none-any.whl → 0.5.2.dev20241212055403__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
@@ -59,6 +59,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
59
59
  weight,
60
60
  target,
61
61
  bias=None,
62
+ ref_input=None,
62
63
  ref_weight=None,
63
64
  ref_bias=None,
64
65
  ignore_index=-100,
@@ -79,6 +80,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
79
80
  compute_nll_loss=compute_nll_loss,
80
81
  compiled=compiled,
81
82
  use_ref_model=use_ref_model,
83
+ ref_input=ref_input,
82
84
  ref_weight=ref_weight,
83
85
  ref_bias=ref_bias,
84
86
  )
@@ -86,7 +88,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
86
88
  @staticmethod
87
89
  def backward(ctx, *grad_output):
88
90
  grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
89
- return *grads, None, None, None, None, None, None, None
91
+ return *grads, None, None, None, None, None, None, None, None
90
92
 
91
93
 
92
94
  class LigerFusedLinearDPOLoss(torch.nn.Module):
@@ -118,13 +120,21 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
118
120
  self.use_ref_model = use_ref_model
119
121
 
120
122
  def forward(
121
- self, lin_weight, _input, target, bias=None, ref_weight=None, ref_bias=None
123
+ self,
124
+ lin_weight,
125
+ _input,
126
+ target,
127
+ bias=None,
128
+ ref_input=None,
129
+ ref_weight=None,
130
+ ref_bias=None,
122
131
  ):
123
132
  return LigerFusedLinearDPOFunction.apply(
124
133
  _input,
125
134
  lin_weight,
126
135
  target,
127
136
  bias,
137
+ ref_input,
128
138
  ref_weight,
129
139
  ref_bias,
130
140
  self.ignore_index,
@@ -29,7 +29,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
29
29
  compute_nll_loss=True,
30
30
  compiled=True,
31
31
  use_ref_model=False,
32
- # TODO: ref input
32
+ ref_input=None,
33
33
  ref_weight=None,
34
34
  ref_bias=None,
35
35
  **loss_kwargs,
@@ -97,20 +97,26 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
97
97
  **loss_kwargs,
98
98
  )
99
99
 
100
- def fused_fwd_bwd(input_chunk, target_chunk):
100
+ def fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk):
101
101
  """
102
102
  Fused forward and backward pass for a chunk of input and target.
103
103
  """
104
104
  if bias is not None:
105
105
  return torch.func.grad_and_value(
106
106
  compute_loss, argnums=(0, 1, 3), has_aux=True
107
- )(input_chunk, weight, target_chunk, bias)
107
+ )(
108
+ input_chunk,
109
+ weight,
110
+ target_chunk,
111
+ bias,
112
+ ref_input_chunk=ref_input_chunk,
113
+ )
108
114
  else:
109
115
  return torch.func.grad_and_value(
110
116
  compute_loss, argnums=(0, 1), has_aux=True
111
- )(input_chunk, weight, target_chunk)
117
+ )(input_chunk, weight, target_chunk, ref_input_chunk=ref_input_chunk)
112
118
 
113
- def accumulate_chunk(input_chunk, target_chunk):
119
+ def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None):
114
120
  if bias is not None:
115
121
  (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (
116
122
  chunk_loss,
@@ -122,7 +128,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
122
128
  chunk_nll_loss,
123
129
  *aux_outputs,
124
130
  ),
125
- ) = fused_fwd_bwd(input_chunk, target_chunk)
131
+ ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)
126
132
  grad_bias.add_(chunk_grad_bias) # accumulate bias gradient
127
133
  else:
128
134
  (chunk_grad_input, chunk_grad_weight), (
@@ -135,7 +141,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
135
141
  chunk_nll_loss,
136
142
  *aux_outputs,
137
143
  ),
138
- ) = fused_fwd_bwd(input_chunk, target_chunk)
144
+ ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)
139
145
 
140
146
  # Accumulate gradients
141
147
  grad_weight.add_(chunk_grad_weight)
@@ -182,18 +188,43 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
182
188
  _rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
183
189
  _rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)
184
190
 
191
+ if use_ref_model:
192
+ _ref_chosen_input_chunks = torch.chunk(
193
+ ref_input[:len_chosen], chunks=chunks, dim=0
194
+ )
195
+ _ref_rejected_input_chunks = torch.chunk(
196
+ ref_input[len_chosen:], chunks=chunks, dim=0
197
+ )
198
+
185
199
  for (
186
200
  chosen_input_chunk,
187
201
  rejected_input_chunk,
188
202
  chosen_target_chunk,
189
203
  rejected_target_chunk,
204
+ ref_chosen_input_chunk,
205
+ ref_rejected_input_chunk,
190
206
  ) in zip(
191
207
  _chosen_input_chunks,
192
208
  _rejected_input_chunks,
193
209
  _chosen_target_chunks,
194
210
  _rejected_target_chunks,
211
+ (
212
+ _ref_chosen_input_chunks
213
+ if use_ref_model
214
+ else [None] * len(_chosen_input_chunks)
215
+ ),
216
+ (
217
+ _ref_rejected_input_chunks
218
+ if use_ref_model
219
+ else [None] * len(_rejected_input_chunks)
220
+ ),
195
221
  ):
196
222
  input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
223
+ ref_input_chunk = (
224
+ torch.cat([ref_chosen_input_chunk, ref_rejected_input_chunk], dim=0)
225
+ if use_ref_model
226
+ else None
227
+ )
197
228
  target_chunk = torch.cat(
198
229
  [chosen_target_chunk, rejected_target_chunk], dim=0
199
230
  )
@@ -202,9 +233,10 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
202
233
  torch._dynamo.mark_dynamic(input_chunk, 1)
203
234
  torch._dynamo.mark_dynamic(target_chunk, 1)
204
235
  torch._dynamo.mark_dynamic(target, 1)
236
+ torch._dynamo.mark_dynamic(ref_input_chunk, 1) if use_ref_model else None
205
237
 
206
238
  # accumulate loss, gradients, and metrics
207
- accumulate_chunk(input_chunk, target_chunk)
239
+ accumulate_chunk(input_chunk, target_chunk, ref_input_chunk)
208
240
 
209
241
  # combine grad_chosen_inputs and grad_rejected_inputs
210
242
  grad_inputs = grad_chosen_inputs + grad_rejected_inputs
@@ -301,6 +333,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
301
333
  beta=0.1,
302
334
  compute_nll_loss=True,
303
335
  use_ref_model=False,
336
+ ref_input_chunk=None,
304
337
  ref_weight=None,
305
338
  ref_bias=None,
306
339
  **loss_kwargs,
@@ -357,7 +390,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
357
390
  ref_rejected_logits,
358
391
  ref_chosen_nll_loss,
359
392
  ) = LigerFusedLinearPreferenceBase.chunk_forward(
360
- input_chunk,
393
+ ref_input_chunk,
361
394
  ref_weight,
362
395
  target_chunk,
363
396
  ref_bias,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.2.dev20241212030605
3
+ Version: 0.5.2.dev20241212055403
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -4,10 +4,10 @@ liger_kernel/utils.py,sha256=HJa-xVKOohDn6pLVIx-Fv0V9h0QAL3qZGQNRICI-OpI,249
4
4
  liger_kernel/chunked_loss/README.md,sha256=K6rucm6nqHpWCmxUOhBYcE3apwQxAy0TfRUippR7Icw,2243
5
5
  liger_kernel/chunked_loss/__init__.py,sha256=R2wCcz4Y0kTAve926DH3k182XKezpXeACMHj05g9Mm8,346
6
6
  liger_kernel/chunked_loss/cpo_loss.py,sha256=Qu1Ul2A12sp6CqIT-atPbHWFb_LLtINEA9mOpIRx_0g,3097
7
- liger_kernel/chunked_loss/dpo_loss.py,sha256=H9_RRhclckHYM2sd75tgbnf8IxC_PU2JCALbgtPQvwc,4222
7
+ liger_kernel/chunked_loss/dpo_loss.py,sha256=9S67SzKkLyoBmHGx8bkmthSNHlCT2ikBy9CCdb7wGj0,4381
8
8
  liger_kernel/chunked_loss/functional.py,sha256=9Gr-YXIuEzEJkBUhDx3G2fuQayckLor7cC7svhmPML4,549
9
9
  liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=2BH6DCPjsR2zS6zcwFPcIIZRhLF8SohjGdKsAJ_301o,10222
10
- liger_kernel/chunked_loss/fused_linear_preference.py,sha256=vlWfaaIECWvCQhY9PM7zRI0vKThIrydMf6P44bXn1EE,15114
10
+ liger_kernel/chunked_loss/fused_linear_preference.py,sha256=AsovMdfsOjgWVxtDhZ_rXqpahMsKTg8YueXnZcHt1XQ,16376
11
11
  liger_kernel/chunked_loss/orpo_loss.py,sha256=ZuKGjbkIYzV4UzvupNdq6vyxCp7-BztQkUt8ZnFvKos,3531
12
12
  liger_kernel/chunked_loss/simpo_loss.py,sha256=Wa4LOlDG9PbJkOOkKg8hbKvnKgg7OTBz6-qIkwPK1yw,3275
13
13
  liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -58,9 +58,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=c4OQVJmhNOloj0JYSEc0j_cQuBb
58
58
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=jko6oq_XQdBSmXubp05E-_YXOyhtB5Bj75dg5YNwOsE,7517
59
59
  liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
60
60
  liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
61
- liger_kernel_nightly-0.5.2.dev20241212030605.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
62
- liger_kernel_nightly-0.5.2.dev20241212030605.dist-info/METADATA,sha256=PeHGuXRXme-T4S249Fh6IWDCNH2-DMWzhyrs2i9MiyE,20260
63
- liger_kernel_nightly-0.5.2.dev20241212030605.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
64
- liger_kernel_nightly-0.5.2.dev20241212030605.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
65
- liger_kernel_nightly-0.5.2.dev20241212030605.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
66
- liger_kernel_nightly-0.5.2.dev20241212030605.dist-info/RECORD,,
61
+ liger_kernel_nightly-0.5.2.dev20241212055403.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
62
+ liger_kernel_nightly-0.5.2.dev20241212055403.dist-info/METADATA,sha256=kgWj3JislmGGGRKJjny99tjjHfqgdim23iNGAeZhwfk,20260
63
+ liger_kernel_nightly-0.5.2.dev20241212055403.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
64
+ liger_kernel_nightly-0.5.2.dev20241212055403.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
65
+ liger_kernel_nightly-0.5.2.dev20241212055403.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
66
+ liger_kernel_nightly-0.5.2.dev20241212055403.dist-info/RECORD,,