liger-kernel-nightly 0.5.5.dev20250314002525-py3-none-any.whl → 0.5.5.dev20250314203927-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


liger_kernel/chunked_loss/fused_linear_distillation.py CHANGED
@@ -117,7 +117,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
 
         hard_loss /= full_target.shape[0]
 
-        soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk)
+        soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, **loss_kwargs)
         soft_loss /= full_target.shape[0]
 
         loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
@@ -180,9 +180,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             ignore_index=ignore_index,
             weight_hard_loss=weight_hard_loss,
             weight_soft_loss=weight_soft_loss,
-            beta=beta,
             compute_ce_loss=compute_ce_loss,
             temperature=temperature,
+            beta=beta,
             **loss_kwargs,
         )
 
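For reference, forwarding **loss_kwargs means any extra keyword (such as beta) reaches the distillation loss function without changing the base class signature. A minimal standalone sketch of that pattern, with a made-up toy loss and tensors (not library code):

import torch
import torch.nn.functional as F

# Toy distillation loss (illustrative only): it receives an extra `beta`
# keyword through **loss_kwargs, mirroring the call-site change above.
def toy_distillation_loss(student_logits, teacher_logits, beta=0.5):
    student_log_probs = F.log_softmax(student_logits, dim=-1)
    teacher_probs = F.softmax(teacher_logits, dim=-1)
    return beta * F.kl_div(student_log_probs, teacher_probs, reduction="sum")

def compute_soft_loss(student_chunk, teacher_chunk, distillation_loss_fn, **loss_kwargs):
    # The caller only forwards **loss_kwargs; the loss function picks what it needs.
    return distillation_loss_fn(student_chunk, teacher_chunk, **loss_kwargs)

student, teacher = torch.randn(4, 8), torch.randn(4, 8)
print(compute_soft_loss(student, teacher, toy_distillation_loss, beta=0.9))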
liger_kernel/chunked_loss/jsd_loss.py CHANGED
@@ -19,15 +19,20 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
         student_log_probs = F.log_softmax(student_logits, dim=-1)
         teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
 
-        # Compute probabilities (only required for mean calculation)
-        mean_probs = beta * student_log_probs.exp() + (1 - beta) * teacher_log_probs.exp()
-        log_mean_probs = mean_probs.log()
+        if beta == 0:
+            jsd_loss = F.kl_div(student_log_probs, teacher_log_probs, reduction="sum", log_target=True)
+        elif beta == 1:
+            jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="sum", log_target=True)
+        else:
+            # Compute probabilities (only required for mean calculation)
+            mean_probs = (1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp()
+            log_mean_probs = mean_probs.log()
 
-        student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
-        teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
+            student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
+            teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
 
-        # JSD is the weighted average of the KL divergences
-        jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
+            # JSD is the weighted average of the KL divergences
+            jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
         return jsd_loss
 
     @classmethod
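The change above special-cases the endpoints (beta=0 is forward KL, beta=1 is reverse KL) and swaps the mixture weights so beta interpolates between the two. A self-contained eager-mode sketch of the same computation, outside the chunked/fused path (the function name here is illustrative):

import torch
import torch.nn.functional as F

def generalized_jsd(student_logits, teacher_logits, beta=0.5):
    # Eager reference for the branches above; sums over all elements.
    student_log_probs = F.log_softmax(student_logits, dim=-1)
    teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
    if beta == 0:  # forward KL(teacher || student)
        return F.kl_div(student_log_probs, teacher_log_probs, reduction="sum", log_target=True)
    if beta == 1:  # reverse KL(student || teacher)
        return F.kl_div(teacher_log_probs, student_log_probs, reduction="sum", log_target=True)
    # Interior beta: weighted KL of student and teacher against the mixture M.
    mean_probs = (1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp()
    log_mean_probs = mean_probs.log()
    student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
    teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
    return beta * teacher_kl + (1 - beta) * student_kl

s, t = torch.randn(2, 5), torch.randn(2, 5)
print(generalized_jsd(s, t, 0.0), generalized_jsd(s, t, 0.5), generalized_jsd(s, t, 1.0))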
liger_kernel/ops/jsd.py CHANGED
@@ -51,24 +51,43 @@ def _jsd_kernel(
     Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
 
     if beta == 0.0:  # forward KL
-        Y_prob = tl.exp(Y)
+        Y_max = tl.max(Y, axis=0)
+        Y_shifted = Y - Y_max
+        Y_prob = tl.exp(Y_shifted) * tl.exp(Y_max)  # Compensate for the shift
         loss = Y_prob * (Y - X)
         dX = -Y_prob
-    elif beta == 1.0:
-        X_prob = tl.exp(X)
+    elif beta == 1.0:  # reverse KL
+        X_max = tl.max(X, axis=0)
+        X_shifted = X - X_max
+        X_prob = tl.exp(X_shifted) * tl.exp(X_max)  # Compensate for the shift
         loss = X_prob * (X - Y)
         dX = loss + X_prob
     else:
-        Q = tl.exp(X)
-        P = tl.exp(Y)
-        M = beta * P + (1 - beta) * Q
-        log_M = tl.log(M)
+        max_val = tl.maximum(tl.max(X, axis=0), tl.max(Y, axis=0))
+        X_shifted = X - max_val
+        Y_shifted = Y - max_val
 
-        loss = beta * P * Y + (1 - beta) * Q * X - M * log_M
-        dX = (1 - beta) * Q * (X - log_M)
+        # Pre-compute exp(max_val) since it's used twice
+        exp_max = tl.exp(max_val)
+
+        # Compute exp terms with compensation
+        Q = tl.exp(X_shifted) * exp_max  # = exp(X)
+        P = tl.exp(Y_shifted) * exp_max  # = exp(Y)
+
+        # Pre-compute common terms
+        beta_P = beta * P
+        one_minus_beta_Q = (1 - beta) * Q
+        M = beta_P + one_minus_beta_Q
+        log_M = tl.log(M)  # No need to compensate as M is already in original scale
+
+        loss = beta_P * Y + one_minus_beta_Q * X - M * log_M
+        dX = one_minus_beta_Q * (X - log_M)
+
+    # Pre-compute scaling factor
+    scale = 1.0 / n_non_ignore
+    loss = loss * scale
+    dX = dX * scale
 
-    loss = loss / n_non_ignore
-    dX = dX / n_non_ignore
     tl.store(loss_ptr + offsets, loss, mask=mask)
     tl.store(dX_ptr + offsets, dX, mask=mask)
 
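The kernel rewrite uses the identity exp(x) = exp(x - m) * exp(m) so the exponentials are taken on shifted values, and it folds the division by n_non_ignore into a single multiplication by a precomputed scale. A plain PyTorch reference for the else branch (generalized JSD), written to mirror the kernel's variable names; this is a sketch for checking the math, not the Triton code itself:

import torch

def jsd_reference(X, Y, beta, n_non_ignore):
    # X: student log-probs, Y: teacher log-probs (one row of the kernel's tile).
    max_val = torch.maximum(X.max(), Y.max())
    exp_max = torch.exp(max_val)
    Q = torch.exp(X - max_val) * exp_max  # = exp(X)
    P = torch.exp(Y - max_val) * exp_max  # = exp(Y)
    M = beta * P + (1 - beta) * Q
    log_M = torch.log(M)
    loss = beta * P * Y + (1 - beta) * Q * X - M * log_M
    dX = (1 - beta) * Q * (X - log_M)
    scale = 1.0 / n_non_ignore
    return loss * scale, dX * scale

X = torch.log_softmax(torch.randn(8), dim=-1)
Y = torch.log_softmax(torch.randn(8), dim=-1)
# The shift itself is exact: exp(x - m) * exp(m) == exp(x).
assert torch.allclose(torch.exp(X - X.max()) * torch.exp(X.max()), torch.exp(X))
loss, dX = jsd_reference(X, Y, beta=0.5, n_non_ignore=1)
print(loss.sum(), dX)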
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel_nightly
-Version: 0.5.5.dev20250314002525
+Version: 0.5.5.dev20250314203927
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
         Copyright 2024 LinkedIn Corporation
@@ -6,12 +6,12 @@ liger_kernel/chunked_loss/__init__.py,sha256=ATu-xX5Fc49Cr6yBOGBRNTo593ZrU5ZCsIu
 liger_kernel/chunked_loss/cpo_loss.py,sha256=Gzz1eU4kgcbdubFVRy55e8A1Cr-r45UgNicXwZIjmBU,5454
 liger_kernel/chunked_loss/dpo_loss.py,sha256=xZwGqS04si9zXyob95SAdalC-hajZg8fWINqiqffN8k,5855
 liger_kernel/chunked_loss/functional.py,sha256=THWWpCnRVhTVfnPnyvQjdBvo1JDtxhwLmtZE_yiBBqM,817
-liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=FJh7k3sry-fqnBApLSngf7h-lHQEiXtOY_tiRDVanPM,11022
+liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=oeZhRw87UUo01UotfaMxDhWa7Xr6IERmK3zzF1CQqEc,11037
 liger_kernel/chunked_loss/fused_linear_preference.py,sha256=ojB42jYPu0c4ki96Ft-hy7Sf6fh_WikG-aWNrlZzSio,18362
 liger_kernel/chunked_loss/fused_linear_rlhf.py,sha256=wGujqwLz91mOE9MmdenhBIKvbmswhwtINMCpcP7D74c,9050
 liger_kernel/chunked_loss/fused_linear_unpaired_preference.py,sha256=RiuK3UtRwH9T6jZ36sA8Urj-TVuOLOO2syLg_JOQapY,13437
 liger_kernel/chunked_loss/grpo_loss.py,sha256=axED3628yKODu1v7PMAvSd08WZqwNQvJOTUYMgcihdQ,6665
-liger_kernel/chunked_loss/jsd_loss.py,sha256=j2_1AYLu0FW2VQJIEr1J1qHsWd5VUo6C3aedglHVH4Y,6771
+liger_kernel/chunked_loss/jsd_loss.py,sha256=u2ahkuHsbhpNaKcpBCz5gCMDk9ou-P04DHji592dIBo,7067
 liger_kernel/chunked_loss/kto_loss.py,sha256=llVCe6DkcpCo57seGWoMikaQVFApx764jsmSbQyqwQY,7529
 liger_kernel/chunked_loss/orpo_loss.py,sha256=nu9UYG16dcMw93lvHi4_hYs3Q0FK1KnlmMRj7OpYU8s,4872
 liger_kernel/chunked_loss/simpo_loss.py,sha256=fy2w8KbhMrBv7b1jdIeH3bBFxY52bPQPZb3KwBvmurM,5385
@@ -21,7 +21,7 @@ liger_kernel/ops/fused_linear_cross_entropy.py,sha256=1Y3Uk_TCSjqKgoG2eot1ptnWXJ
 liger_kernel/ops/fused_linear_jsd.py,sha256=Seshez2qaM6HiTQ8_HEqSwhaeVruNT1SvIM4ZrAPBEU,9602
 liger_kernel/ops/geglu.py,sha256=axGvCIvlBzuluoAIrWTsp2iZM4BFKNInkPov8YVvH9E,4126
 liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2woggk,10837
-liger_kernel/ops/jsd.py,sha256=0jNeRxpcNI5ckxCdoCNyO5GEedLIuzx3lz6KAiksc4o,6109
+liger_kernel/ops/jsd.py,sha256=rkloGA7nDfVaa5nKY6-EYBw0E1p_MSsl4fr2xZGTp04,6961
 liger_kernel/ops/kl_div.py,sha256=MnfuYqqQESON1X2Swy064x1urKtMFdgeSWd60VttBXI,8420
 liger_kernel/ops/layer_norm.py,sha256=6roQjioyg-9O2qLPV8nL4U0-5UH80tdzOMTWwjvDnn8,7961
 liger_kernel/ops/qwen2vl_mrope.py,sha256=3GExhYpLgB4VUtyZyjRk8XjEur3W4EWF6HQ67ML5vBU,8481
@@ -67,9 +67,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.5.5.dev20250314002525.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
-liger_kernel_nightly-0.5.5.dev20250314002525.dist-info/METADATA,sha256=auKRFqG0RTHHc_8Sfk_3RfEmSimuwPwSmbsCSRpNhgU,22390
-liger_kernel_nightly-0.5.5.dev20250314002525.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
-liger_kernel_nightly-0.5.5.dev20250314002525.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
-liger_kernel_nightly-0.5.5.dev20250314002525.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
-liger_kernel_nightly-0.5.5.dev20250314002525.dist-info/RECORD,,
+liger_kernel_nightly-0.5.5.dev20250314203927.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.5.dev20250314203927.dist-info/METADATA,sha256=Fomxuo8mGYVe9Um1hCaEKQ0PyfYic7JJfatd3BZIrz0,22390
+liger_kernel_nightly-0.5.5.dev20250314203927.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.5.dev20250314203927.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+liger_kernel_nightly-0.5.5.dev20250314203927.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.5.dev20250314203927.dist-info/RECORD,,