liger-kernel-nightly 0.5.2.dev20241212030605__py3-none-any.whl → 0.5.2.dev20241212055403__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -59,6 +59,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
59
59
  weight,
60
60
  target,
61
61
  bias=None,
62
+ ref_input=None,
62
63
  ref_weight=None,
63
64
  ref_bias=None,
64
65
  ignore_index=-100,
@@ -79,6 +80,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
79
80
  compute_nll_loss=compute_nll_loss,
80
81
  compiled=compiled,
81
82
  use_ref_model=use_ref_model,
83
+ ref_input=ref_input,
82
84
  ref_weight=ref_weight,
83
85
  ref_bias=ref_bias,
84
86
  )
@@ -86,7 +88,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
86
88
  @staticmethod
87
89
  def backward(ctx, *grad_output):
88
90
  grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
89
- return *grads, None, None, None, None, None, None, None
91
+ return *grads, None, None, None, None, None, None, None, None
90
92
 
91
93
 
92
94
  class LigerFusedLinearDPOLoss(torch.nn.Module):
@@ -118,13 +120,21 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
118
120
  self.use_ref_model = use_ref_model
119
121
 
120
122
  def forward(
121
- self, lin_weight, _input, target, bias=None, ref_weight=None, ref_bias=None
123
+ self,
124
+ lin_weight,
125
+ _input,
126
+ target,
127
+ bias=None,
128
+ ref_input=None,
129
+ ref_weight=None,
130
+ ref_bias=None,
122
131
  ):
123
132
  return LigerFusedLinearDPOFunction.apply(
124
133
  _input,
125
134
  lin_weight,
126
135
  target,
127
136
  bias,
137
+ ref_input,
128
138
  ref_weight,
129
139
  ref_bias,
130
140
  self.ignore_index,
@@ -29,7 +29,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
29
29
  compute_nll_loss=True,
30
30
  compiled=True,
31
31
  use_ref_model=False,
32
- # TODO: ref input
32
+ ref_input=None,
33
33
  ref_weight=None,
34
34
  ref_bias=None,
35
35
  **loss_kwargs,
@@ -97,20 +97,26 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
97
97
  **loss_kwargs,
98
98
  )
99
99
 
100
- def fused_fwd_bwd(input_chunk, target_chunk):
100
+ def fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk):
101
101
  """
102
102
  Fused forward and backward pass for a chunk of input and target.
103
103
  """
104
104
  if bias is not None:
105
105
  return torch.func.grad_and_value(
106
106
  compute_loss, argnums=(0, 1, 3), has_aux=True
107
- )(input_chunk, weight, target_chunk, bias)
107
+ )(
108
+ input_chunk,
109
+ weight,
110
+ target_chunk,
111
+ bias,
112
+ ref_input_chunk=ref_input_chunk,
113
+ )
108
114
  else:
109
115
  return torch.func.grad_and_value(
110
116
  compute_loss, argnums=(0, 1), has_aux=True
111
- )(input_chunk, weight, target_chunk)
117
+ )(input_chunk, weight, target_chunk, ref_input_chunk=ref_input_chunk)
112
118
 
113
- def accumulate_chunk(input_chunk, target_chunk):
119
+ def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None):
114
120
  if bias is not None:
115
121
  (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (
116
122
  chunk_loss,
@@ -122,7 +128,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
122
128
  chunk_nll_loss,
123
129
  *aux_outputs,
124
130
  ),
125
- ) = fused_fwd_bwd(input_chunk, target_chunk)
131
+ ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)
126
132
  grad_bias.add_(chunk_grad_bias) # accumulate bias gradient
127
133
  else:
128
134
  (chunk_grad_input, chunk_grad_weight), (
@@ -135,7 +141,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
135
141
  chunk_nll_loss,
136
142
  *aux_outputs,
137
143
  ),
138
- ) = fused_fwd_bwd(input_chunk, target_chunk)
144
+ ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)
139
145
 
140
146
  # Accumulate gradients
141
147
  grad_weight.add_(chunk_grad_weight)
@@ -182,18 +188,43 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
182
188
  _rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
183
189
  _rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)
184
190
 
191
+ if use_ref_model:
192
+ _ref_chosen_input_chunks = torch.chunk(
193
+ ref_input[:len_chosen], chunks=chunks, dim=0
194
+ )
195
+ _ref_rejected_input_chunks = torch.chunk(
196
+ ref_input[len_chosen:], chunks=chunks, dim=0
197
+ )
198
+
185
199
  for (
186
200
  chosen_input_chunk,
187
201
  rejected_input_chunk,
188
202
  chosen_target_chunk,
189
203
  rejected_target_chunk,
204
+ ref_chosen_input_chunk,
205
+ ref_rejected_input_chunk,
190
206
  ) in zip(
191
207
  _chosen_input_chunks,
192
208
  _rejected_input_chunks,
193
209
  _chosen_target_chunks,
194
210
  _rejected_target_chunks,
211
+ (
212
+ _ref_chosen_input_chunks
213
+ if use_ref_model
214
+ else [None] * len(_chosen_input_chunks)
215
+ ),
216
+ (
217
+ _ref_rejected_input_chunks
218
+ if use_ref_model
219
+ else [None] * len(_rejected_input_chunks)
220
+ ),
195
221
  ):
196
222
  input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
223
+ ref_input_chunk = (
224
+ torch.cat([ref_chosen_input_chunk, ref_rejected_input_chunk], dim=0)
225
+ if use_ref_model
226
+ else None
227
+ )
197
228
  target_chunk = torch.cat(
198
229
  [chosen_target_chunk, rejected_target_chunk], dim=0
199
230
  )
@@ -202,9 +233,10 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
202
233
  torch._dynamo.mark_dynamic(input_chunk, 1)
203
234
  torch._dynamo.mark_dynamic(target_chunk, 1)
204
235
  torch._dynamo.mark_dynamic(target, 1)
236
+ torch._dynamo.mark_dynamic(ref_input_chunk, 1) if use_ref_model else None
205
237
 
206
238
  # accumulate loss, gradients, and metrics
207
- accumulate_chunk(input_chunk, target_chunk)
239
+ accumulate_chunk(input_chunk, target_chunk, ref_input_chunk)
208
240
 
209
241
  # combine grad_chosen_inputs and grad_rejected_inputs
210
242
  grad_inputs = grad_chosen_inputs + grad_rejected_inputs
@@ -301,6 +333,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
301
333
  beta=0.1,
302
334
  compute_nll_loss=True,
303
335
  use_ref_model=False,
336
+ ref_input_chunk=None,
304
337
  ref_weight=None,
305
338
  ref_bias=None,
306
339
  **loss_kwargs,
@@ -357,7 +390,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
357
390
  ref_rejected_logits,
358
391
  ref_chosen_nll_loss,
359
392
  ) = LigerFusedLinearPreferenceBase.chunk_forward(
360
- input_chunk,
393
+ ref_input_chunk,
361
394
  ref_weight,
362
395
  target_chunk,
363
396
  ref_bias,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.2.dev20241212030605
3
+ Version: 0.5.2.dev20241212055403
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -4,10 +4,10 @@ liger_kernel/utils.py,sha256=HJa-xVKOohDn6pLVIx-Fv0V9h0QAL3qZGQNRICI-OpI,249
4
4
  liger_kernel/chunked_loss/README.md,sha256=K6rucm6nqHpWCmxUOhBYcE3apwQxAy0TfRUippR7Icw,2243
5
5
  liger_kernel/chunked_loss/__init__.py,sha256=R2wCcz4Y0kTAve926DH3k182XKezpXeACMHj05g9Mm8,346
6
6
  liger_kernel/chunked_loss/cpo_loss.py,sha256=Qu1Ul2A12sp6CqIT-atPbHWFb_LLtINEA9mOpIRx_0g,3097
7
- liger_kernel/chunked_loss/dpo_loss.py,sha256=H9_RRhclckHYM2sd75tgbnf8IxC_PU2JCALbgtPQvwc,4222
7
+ liger_kernel/chunked_loss/dpo_loss.py,sha256=9S67SzKkLyoBmHGx8bkmthSNHlCT2ikBy9CCdb7wGj0,4381
8
8
  liger_kernel/chunked_loss/functional.py,sha256=9Gr-YXIuEzEJkBUhDx3G2fuQayckLor7cC7svhmPML4,549
9
9
  liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=2BH6DCPjsR2zS6zcwFPcIIZRhLF8SohjGdKsAJ_301o,10222
10
- liger_kernel/chunked_loss/fused_linear_preference.py,sha256=vlWfaaIECWvCQhY9PM7zRI0vKThIrydMf6P44bXn1EE,15114
10
+ liger_kernel/chunked_loss/fused_linear_preference.py,sha256=AsovMdfsOjgWVxtDhZ_rXqpahMsKTg8YueXnZcHt1XQ,16376
11
11
  liger_kernel/chunked_loss/orpo_loss.py,sha256=ZuKGjbkIYzV4UzvupNdq6vyxCp7-BztQkUt8ZnFvKos,3531
12
12
  liger_kernel/chunked_loss/simpo_loss.py,sha256=Wa4LOlDG9PbJkOOkKg8hbKvnKgg7OTBz6-qIkwPK1yw,3275
13
13
  liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -58,9 +58,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=c4OQVJmhNOloj0JYSEc0j_cQuBb
58
58
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=jko6oq_XQdBSmXubp05E-_YXOyhtB5Bj75dg5YNwOsE,7517
59
59
  liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
60
60
  liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
61
- liger_kernel_nightly-0.5.2.dev20241212030605.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
62
- liger_kernel_nightly-0.5.2.dev20241212030605.dist-info/METADATA,sha256=PeHGuXRXme-T4S249Fh6IWDCNH2-DMWzhyrs2i9MiyE,20260
63
- liger_kernel_nightly-0.5.2.dev20241212030605.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
64
- liger_kernel_nightly-0.5.2.dev20241212030605.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
65
- liger_kernel_nightly-0.5.2.dev20241212030605.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
66
- liger_kernel_nightly-0.5.2.dev20241212030605.dist-info/RECORD,,
61
+ liger_kernel_nightly-0.5.2.dev20241212055403.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
62
+ liger_kernel_nightly-0.5.2.dev20241212055403.dist-info/METADATA,sha256=kgWj3JislmGGGRKJjny99tjjHfqgdim23iNGAeZhwfk,20260
63
+ liger_kernel_nightly-0.5.2.dev20241212055403.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
64
+ liger_kernel_nightly-0.5.2.dev20241212055403.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
65
+ liger_kernel_nightly-0.5.2.dev20241212055403.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
66
+ liger_kernel_nightly-0.5.2.dev20241212055403.dist-info/RECORD,,