liger-kernel-nightly 0.5.2.dev20250108072837__py3-none-any.whl → 0.5.2.dev20250108102127__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -65,6 +65,7 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
65
65
  beta=beta,
66
66
  label_smoothing=label_smoothing,
67
67
  compute_nll_loss=compute_nll_loss,
68
+ average_log_prob=False,
68
69
  compiled=compiled,
69
70
  )
70
71
 
@@ -32,6 +32,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
32
32
  ref_input=None,
33
33
  ref_weight=None,
34
34
  ref_bias=None,
35
+ average_log_prob=True,
35
36
  **loss_kwargs,
36
37
  ):
37
38
  """
@@ -61,6 +62,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
61
62
  use_ref_model (bool): Whether to use a reference model for the alignment loss.
62
63
  ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
63
64
  ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
65
+ average_log_prob (bool): Whether to average log probabilities or to sum them over the completion.
64
66
  loss_kwargs (dict): Other possible arguments that a loss function might need
65
67
  """
66
68
  # TODO: Tune CHUNK_SIZE to fully utilize the GPU
@@ -94,6 +96,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
94
96
  use_ref_model=use_ref_model,
95
97
  ref_weight=ref_weight,
96
98
  ref_bias=ref_bias,
99
+ average_log_prob=average_log_prob,
97
100
  **loss_kwargs,
98
101
  )
99
102
 
@@ -265,6 +268,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
265
268
  bias=None,
266
269
  ignore_index=-100,
267
270
  compute_nll_loss=True,
271
+ average_log_prob=True,
268
272
  ):
269
273
  len_chosen_chunk = target_chunk.shape[0] // 2
270
274
  logits_chunk = input_chunk @ weight.t()
@@ -285,10 +289,13 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
285
289
  label_chunk = torch.where(loss_mask, target_chunk, 0)
286
290
 
287
291
  per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
288
- average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
292
+ if average_log_prob:
293
+ log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
294
+ else:
295
+ log_prob = (per_token_logps * loss_mask).sum(-1)
289
296
 
290
- chosen_logps = average_log_prob[:len_chosen_chunk]
291
- rejected_logps = average_log_prob[len_chosen_chunk:]
297
+ chosen_logps = log_prob[:len_chosen_chunk]
298
+ rejected_logps = log_prob[len_chosen_chunk:]
292
299
 
293
300
  chosen_logits = logits_chunk[:len_chosen_chunk]
294
301
  rejected_logits = logits_chunk[len_chosen_chunk:]
@@ -317,6 +324,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
317
324
  ref_input_chunk=None,
318
325
  ref_weight=None,
319
326
  ref_bias=None,
327
+ average_log_prob=True,
320
328
  **loss_kwargs,
321
329
  ):
322
330
  """
@@ -335,6 +343,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
335
343
  use_ref_model (bool): Whether to use a reference model for the alignment loss.
336
344
  ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
337
345
  ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
346
+ average_log_prob (bool): Whether to average log probabilities or the sum.
338
347
  loss_kwargs (dict): Additional arguments for the loss function.
339
348
  """
340
349
  (
@@ -350,6 +359,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
350
359
  bias=bias,
351
360
  ignore_index=ignore_index,
352
361
  compute_nll_loss=compute_nll_loss,
362
+ average_log_prob=average_log_prob,
353
363
  )
354
364
  chosen_nll_loss = chosen_nll_loss / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
355
365
  chosen_logits_mean = chosen_logits.sum() / (full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0])
@@ -372,6 +382,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
372
382
  ref_bias,
373
383
  ignore_index=ignore_index,
374
384
  compute_nll_loss=False, # We don't need NLL loss for the reference model
385
+ average_log_prob=average_log_prob,
375
386
  )
376
387
  loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
377
388
  loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
@@ -20,9 +20,6 @@ if compare_version("triton", operator.ge, "3.0.0"):
20
20
  else:
21
21
  from triton.language.math import tanh
22
22
 
23
- _TRUE: tl.constexpr = tl.constexpr(1)
24
- _FALSE: tl.constexpr = tl.constexpr(0)
25
-
26
23
 
27
24
  @triton.jit
28
25
  def liger_cross_entropy_kernel(
@@ -95,7 +92,7 @@ def liger_cross_entropy_kernel(
95
92
  return
96
93
 
97
94
  loss_ptr += program_id * loss_stride
98
- if RETURN_Z_LOSS == _TRUE:
95
+ if RETURN_Z_LOSS:
99
96
  z_loss_ptr += program_id * loss_stride
100
97
 
101
98
  if HAS_WEIGHT:
@@ -254,7 +251,7 @@ def liger_cross_entropy_kernel(
254
251
  loss += z_loss
255
252
 
256
253
  tl.store(loss_ptr, loss)
257
- if RETURN_Z_LOSS == _TRUE:
254
+ if RETURN_Z_LOSS:
258
255
  tl.store(z_loss_ptr, z_loss)
259
256
 
260
257
 
@@ -264,12 +261,6 @@ def liger_cross_entropy_kernel(
264
261
  MAX_FUSED_SIZE = 65536 // 2 # the best size we found by manually tuning
265
262
 
266
263
 
267
- _bool_to_return_z_loss = {
268
- True: _TRUE.value,
269
- False: _FALSE.value,
270
- }
271
-
272
-
273
264
  def cross_entropy_forward(
274
265
  _input,
275
266
  target,
@@ -281,11 +272,7 @@ def cross_entropy_forward(
281
272
  softcap,
282
273
  return_z_loss,
283
274
  ):
284
- if not isinstance(return_z_loss, int):
285
- assert return_z_loss in _bool_to_return_z_loss, f"return_z_loss must be True or False. Got: {return_z_loss}"
286
- return_z_loss = _bool_to_return_z_loss[return_z_loss]
287
- else:
288
- assert return_z_loss in _bool_to_return_z_loss, f"return_z_loss must be True or False. Got: {return_z_loss}"
275
+ assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
289
276
 
290
277
  BT, V = _input.shape
291
278
  n_rows = BT
@@ -294,10 +281,7 @@ def cross_entropy_forward(
294
281
 
295
282
  # unreduced loss
296
283
  loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
297
- if return_z_loss == _TRUE.value:
298
- z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
299
- else:
300
- z_loss_1d = None # set None when return_z_loss == False
284
+ z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
301
285
 
302
286
  target_mask = target != ignore_index
303
287
  n_non_ignore = target_mask.sum().item()
@@ -326,7 +310,7 @@ def cross_entropy_forward(
326
310
  X_stride=_input.stride(-2),
327
311
  Y_ptr=target,
328
312
  Y_stride=target.stride(-1), # always 1
329
- weight_ptr=weight if weight is not None else _input, # dummy if None
313
+ weight_ptr=weight, # dummy if None
330
314
  loss_ptr=loss_1d,
331
315
  z_loss_ptr=z_loss_1d,
332
316
  loss_stride=loss_1d.stride(-1), # always 1
@@ -338,7 +322,7 @@ def cross_entropy_forward(
338
322
  lse_square_scale=lse_square_scale,
339
323
  label_smoothing=label_smoothing,
340
324
  reduction=reduction,
341
- softcap=softcap if softcap is not None else 0.0,
325
+ softcap=softcap,
342
326
  RETURN_Z_LOSS=return_z_loss,
343
327
  BLOCK_SIZE=BLOCK_SIZE,
344
328
  HAS_WEIGHT=True if weight is not None else False,
@@ -350,10 +334,10 @@ def cross_entropy_forward(
350
334
 
351
335
  if reduction == "none":
352
336
  loss = loss_1d
353
- z_loss = z_loss_1d if return_z_loss == _TRUE.value else None
337
+ z_loss = z_loss_1d if return_z_loss else None
354
338
  else:
355
339
  loss = torch.sum(loss_1d)
356
- z_loss = torch.sum(z_loss_1d) if return_z_loss == _TRUE.value else None
340
+ z_loss = torch.sum(z_loss_1d) if return_z_loss else None
357
341
 
358
342
  return loss, z_loss, _input
359
343
 
@@ -92,9 +92,9 @@ def fused_linear_cross_entropy_forward(
92
92
  X_stride=logits_chunk.stride(-2),
93
93
  Y_ptr=target_chunk,
94
94
  Y_stride=target_chunk.stride(-1), # always 1
95
- weight_ptr=ce_weight if ce_weight is not None else _input, # dummy if None
95
+ weight_ptr=ce_weight,
96
96
  loss_ptr=loss_1d_slice,
97
- z_loss_ptr=loss_1d_slice, # dummy ptr, not used
97
+ z_loss_ptr=None,
98
98
  loss_stride=loss_1d_slice.stride(-1), # always 1
99
99
  n_cols=V,
100
100
  n_non_ignore=total_n_non_ignore,
@@ -104,8 +104,8 @@ def fused_linear_cross_entropy_forward(
104
104
  lse_square_scale=lse_square_scale,
105
105
  label_smoothing=label_smoothing,
106
106
  reduction=reduction,
107
- softcap=softcap if softcap is not None else 0.0,
108
- RETURN_Z_LOSS=0, # False
107
+ softcap=softcap,
108
+ RETURN_Z_LOSS=False,
109
109
  HAS_WEIGHT=True if ce_weight is not None else False,
110
110
  HAS_SOFTCAPPING=True if softcap is not None else False,
111
111
  BLOCK_SIZE=BLOCK_SIZE,
@@ -20,9 +20,6 @@ class LigerCrossEntropyLoss(torch.nn.Module):
20
20
  assert (label_smoothing >= 0) and (
21
21
  label_smoothing <= 1
22
22
  ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
23
- assert (label_smoothing >= 0) and (
24
- label_smoothing <= 1
25
- ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
26
23
  assert reduction in {
27
24
  "mean",
28
25
  "sum",
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.2.dev20250108072837
3
+ Version: 0.5.2.dev20250108102127
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -3,16 +3,16 @@ liger_kernel/env_report.py,sha256=uhdEC8OydxoZlb7B6YYcAaBF3crGFdIck-4cxaW4NJY,17
3
3
  liger_kernel/utils.py,sha256=HJa-xVKOohDn6pLVIx-Fv0V9h0QAL3qZGQNRICI-OpI,249
4
4
  liger_kernel/chunked_loss/README.md,sha256=K6rucm6nqHpWCmxUOhBYcE3apwQxAy0TfRUippR7Icw,2243
5
5
  liger_kernel/chunked_loss/__init__.py,sha256=R2wCcz4Y0kTAve926DH3k182XKezpXeACMHj05g9Mm8,346
6
- liger_kernel/chunked_loss/cpo_loss.py,sha256=MCR4TzuBoJEaU0IJ7dIreLacQeXLKETV5CegNjhCD9M,3646
6
+ liger_kernel/chunked_loss/cpo_loss.py,sha256=OdBR8WYdHTKpLI_c9DcuwqKSWPeAAeTyREz46Vu_cAY,3682
7
7
  liger_kernel/chunked_loss/dpo_loss.py,sha256=VYZMOafdvE8xlhvTtwjrz81tIzxR1mHF4lXdsADnIQg,4373
8
8
  liger_kernel/chunked_loss/functional.py,sha256=9Gr-YXIuEzEJkBUhDx3G2fuQayckLor7cC7svhmPML4,549
9
9
  liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=uQtwtu-kaUZJTjNhAnIr3O794oUlUZ98XR5shYtwP5k,10440
10
- liger_kernel/chunked_loss/fused_linear_preference.py,sha256=25sTgvphLKAR0jyJcrsJPKK1abFpTKrajSyAx8nJ3bc,16134
10
+ liger_kernel/chunked_loss/fused_linear_preference.py,sha256=eQCZmQ3xOL3jpZ7RhOfx_pqR9sNEX6RHx8DtIgyXEHc,16656
11
11
  liger_kernel/chunked_loss/orpo_loss.py,sha256=jbZxx-EjPK71A6CSyNzTOAIEQgAUjfvwSViw6R_pPXQ,3510
12
12
  liger_kernel/chunked_loss/simpo_loss.py,sha256=3TTc7U79Orjgi-Wu81WZkWk5MgsdqKXIOBHgIvDazPw,3865
13
13
  liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
- liger_kernel/ops/cross_entropy.py,sha256=zi2xsa8ky7M1vySUAGjXMQDFQFkKmGQV-myRIIQM13M,19210
15
- liger_kernel/ops/fused_linear_cross_entropy.py,sha256=j7cgR95rFAwtPsWZ00PfMwis5F7dtO3EVEw0rZ1GPJk,10231
14
+ liger_kernel/ops/cross_entropy.py,sha256=SRzAF9Ek84pBVFy3wqQZs7AhRoorKRIgQ-Td_rtl1Kk,18606
15
+ liger_kernel/ops/fused_linear_cross_entropy.py,sha256=hezFRwbcPc-HNGZUFqUn5AYUqUpboPpFh4MNqEW4WgU,10108
16
16
  liger_kernel/ops/fused_linear_jsd.py,sha256=eKqaADj7LgWfoYqyH03tjrmhNTfJOF1Dhx_bWzBTnTU,9600
17
17
  liger_kernel/ops/geglu.py,sha256=axGvCIvlBzuluoAIrWTsp2iZM4BFKNInkPov8YVvH9E,4126
18
18
  liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2woggk,10837
@@ -28,7 +28,7 @@ liger_kernel/ops/experimental/embedding.py,sha256=tolj3tItkzpSb30zWqDN2_yX4ectfl
28
28
  liger_kernel/ops/experimental/mm_int8int2.py,sha256=TrS9lpwekrik_w5qE7AhMJD1bcq-OidjtbsW80oZ6IM,13314
29
29
  liger_kernel/transformers/__init__.py,sha256=QPmYkL6hosBPpPqCUGqvIvAtD9XzLgvZqZxUyYMZeVk,2008
30
30
  liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawXOGgoCZjgEg,1545
31
- liger_kernel/transformers/cross_entropy.py,sha256=s931h9UW_tV4QMRme1HYjS_R2_C5nD6VFmZIXtjJoYo,1840
31
+ liger_kernel/transformers/cross_entropy.py,sha256=LtiHlj_tK2YFpilwvbG_NEVzbf82zKRpWCZMjaFUd4M,1681
32
32
  liger_kernel/transformers/functional.py,sha256=B1wkHWLx-YNhxvXBEXB4Ch1yEwF3mjwTPCeXA5aCV_c,4490
33
33
  liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=LAN8-pjUI2Erz_MnfMer-0ZmxJ0JlKxGzdZGJY-N65g,1569
34
34
  liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
@@ -58,9 +58,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
58
58
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=MId1S_MfA3pPVQA1rkiKxp-jZDNz8VmvZzXC-Kugol4,7662
59
59
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
60
60
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
61
- liger_kernel_nightly-0.5.2.dev20250108072837.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
62
- liger_kernel_nightly-0.5.2.dev20250108072837.dist-info/METADATA,sha256=HwmQEBRYnwwbdkzuW53_qsmTSSbi8qu20cVOHsq6B_s,21055
63
- liger_kernel_nightly-0.5.2.dev20250108072837.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
64
- liger_kernel_nightly-0.5.2.dev20250108072837.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
65
- liger_kernel_nightly-0.5.2.dev20250108072837.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
66
- liger_kernel_nightly-0.5.2.dev20250108072837.dist-info/RECORD,,
61
+ liger_kernel_nightly-0.5.2.dev20250108102127.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
62
+ liger_kernel_nightly-0.5.2.dev20250108102127.dist-info/METADATA,sha256=XHrJlebOzBW0f6tV-rb0iahG9LNI-f86Ar7s-upwoxo,21055
63
+ liger_kernel_nightly-0.5.2.dev20250108102127.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
64
+ liger_kernel_nightly-0.5.2.dev20250108102127.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
65
+ liger_kernel_nightly-0.5.2.dev20250108102127.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
66
+ liger_kernel_nightly-0.5.2.dev20250108102127.dist-info/RECORD,,