liger-kernel-nightly 0.0.1.dev20240819184814__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/README.md +25 -0
  3. liger_kernel/chunked_loss/__init__.py +8 -0
  4. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  5. liger_kernel/chunked_loss/cpo_loss.py +157 -0
  6. liger_kernel/chunked_loss/dpo_loss.py +229 -0
  7. liger_kernel/chunked_loss/functional.py +17 -0
  8. liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
  9. liger_kernel/chunked_loss/fused_linear_ppo.py +366 -0
  10. liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
  11. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
  12. liger_kernel/chunked_loss/grpo_loss.py +307 -0
  13. liger_kernel/chunked_loss/jsd_loss.py +200 -0
  14. liger_kernel/chunked_loss/kto_loss.py +210 -0
  15. liger_kernel/chunked_loss/orpo_loss.py +144 -0
  16. liger_kernel/chunked_loss/simpo_loss.py +165 -0
  17. liger_kernel/env_report.py +63 -0
  18. liger_kernel/ops/__init__.py +141 -0
  19. liger_kernel/ops/backends/README.md +151 -0
  20. liger_kernel/ops/backends/__init__.py +13 -0
  21. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  22. liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
  23. liger_kernel/ops/backends/registry.py +61 -0
  24. liger_kernel/ops/cross_entropy.py +383 -114
  25. liger_kernel/ops/dyt.py +160 -0
  26. liger_kernel/ops/experimental/embedding.py +141 -0
  27. liger_kernel/ops/experimental/mm_int8int2.py +349 -0
  28. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  29. liger_kernel/ops/fused_linear_cross_entropy.py +346 -132
  30. liger_kernel/ops/fused_linear_jsd.py +228 -0
  31. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  32. liger_kernel/ops/geglu.py +66 -64
  33. liger_kernel/ops/group_norm.py +306 -0
  34. liger_kernel/ops/grpo_loss.py +312 -0
  35. liger_kernel/ops/jsd.py +201 -0
  36. liger_kernel/ops/kl_div.py +262 -0
  37. liger_kernel/ops/layer_norm.py +320 -0
  38. liger_kernel/ops/llama4_rope.py +225 -0
  39. liger_kernel/ops/multi_token_attention.py +207 -0
  40. liger_kernel/ops/poly_norm.py +390 -0
  41. liger_kernel/ops/qwen2vl_mrope.py +222 -0
  42. liger_kernel/ops/rms_norm.py +484 -88
  43. liger_kernel/ops/rope.py +122 -117
  44. liger_kernel/ops/softmax.py +201 -0
  45. liger_kernel/ops/sparsemax.py +179 -0
  46. liger_kernel/ops/swiglu.py +68 -65
  47. liger_kernel/ops/tiled_mlp.py +136 -0
  48. liger_kernel/ops/tvd.py +207 -0
  49. liger_kernel/ops/utils.py +82 -3
  50. liger_kernel/transformers/__init__.py +218 -6
  51. liger_kernel/transformers/auto_model.py +38 -0
  52. liger_kernel/transformers/cross_entropy.py +52 -7
  53. liger_kernel/transformers/dyt.py +22 -0
  54. liger_kernel/transformers/experimental/__init__.py +5 -0
  55. liger_kernel/transformers/experimental/embedding.py +26 -0
  56. liger_kernel/transformers/fsdp.py +55 -0
  57. liger_kernel/transformers/functional.py +301 -0
  58. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  59. liger_kernel/transformers/fused_linear_cross_entropy.py +59 -10
  60. liger_kernel/transformers/fused_linear_jsd.py +95 -0
  61. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  62. liger_kernel/transformers/geglu.py +6 -7
  63. liger_kernel/transformers/group_norm.py +50 -0
  64. liger_kernel/transformers/grpo_loss.py +153 -0
  65. liger_kernel/transformers/jsd.py +70 -0
  66. liger_kernel/transformers/kl_div.py +12 -0
  67. liger_kernel/transformers/layer_norm.py +24 -0
  68. liger_kernel/transformers/llama4_rope.py +93 -0
  69. liger_kernel/transformers/model/falcon_h1.py +122 -0
  70. liger_kernel/transformers/model/gemma.py +261 -0
  71. liger_kernel/transformers/model/gemma2.py +283 -0
  72. liger_kernel/transformers/model/gemma3.py +332 -0
  73. liger_kernel/transformers/model/glm4.py +141 -0
  74. liger_kernel/transformers/model/glm4v.py +163 -0
  75. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  76. liger_kernel/transformers/model/gpt_oss.py +211 -0
  77. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  78. liger_kernel/transformers/model/internvl.py +157 -0
  79. liger_kernel/transformers/model/llama.py +221 -41
  80. liger_kernel/transformers/model/llama4.py +121 -0
  81. liger_kernel/transformers/model/llava.py +344 -0
  82. liger_kernel/transformers/model/loss_utils.py +95 -0
  83. liger_kernel/transformers/model/mistral.py +145 -0
  84. liger_kernel/transformers/model/mixtral.py +293 -0
  85. liger_kernel/transformers/model/mllama.py +269 -0
  86. liger_kernel/transformers/model/olmo2.py +141 -0
  87. liger_kernel/transformers/model/olmo3.py +142 -0
  88. liger_kernel/transformers/model/output_classes.py +147 -0
  89. liger_kernel/transformers/model/paligemma.py +433 -0
  90. liger_kernel/transformers/model/phi3.py +120 -0
  91. liger_kernel/transformers/model/qwen2.py +259 -0
  92. liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
  93. liger_kernel/transformers/model/qwen2_vl.py +159 -0
  94. liger_kernel/transformers/model/qwen3.py +136 -0
  95. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  96. liger_kernel/transformers/model/qwen3_next.py +146 -0
  97. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  98. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  99. liger_kernel/transformers/model/smollm3.py +199 -0
  100. liger_kernel/transformers/model/smolvlm.py +158 -0
  101. liger_kernel/transformers/monkey_patch.py +2816 -21
  102. liger_kernel/transformers/multi_token_attention.py +64 -0
  103. liger_kernel/transformers/poly_norm.py +42 -0
  104. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  105. liger_kernel/transformers/rms_norm.py +75 -5
  106. liger_kernel/transformers/rope.py +47 -3
  107. liger_kernel/transformers/softmax.py +12 -0
  108. liger_kernel/transformers/sparsemax.py +16 -0
  109. liger_kernel/transformers/swiglu.py +62 -6
  110. liger_kernel/transformers/tiled_mlp.py +133 -0
  111. liger_kernel/transformers/trainer/__init__.py +4 -0
  112. liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
  113. liger_kernel/transformers/trainer_integration.py +2 -45
  114. liger_kernel/transformers/tvd.py +13 -0
  115. liger_kernel/triton/__init__.py +1 -3
  116. liger_kernel/triton/monkey_patch.py +1 -5
  117. liger_kernel/utils.py +96 -0
  118. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/METADATA +447 -0
  119. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/NOTICE +58 -0
  120. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
  121. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +1 -1
  122. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/METADATA +0 -21
  123. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/NOTICE +0 -4
  124. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/RECORD +0 -27
  125. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
  126. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
liger_kernel/ops/grpo_loss.py
@@ -0,0 +1,312 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _selective_log_softmax_kernel(
+    LOGITS,
+    INPUT_IDS,
+    LOG_P,
+    MASK,
+    TEMPERATURE,
+    stride_input_ids_b,
+    L: tl.constexpr,
+    N: tl.constexpr,
+    BLOCK_N: tl.constexpr = 4096,
+):
+    off_b = tl.program_id(0).cast(tl.int64)
+    off_l = tl.program_id(1).cast(tl.int64)
+
+    LOGITS += off_b * (L + 1) * N + off_l * N
+    INPUT_IDS += off_b * stride_input_ids_b + off_l
+    LOG_P += off_b * L + off_l
+
+    if MASK is not None:
+        MASK += off_b * stride_input_ids_b + off_l
+        not_skip = tl.load(MASK)
+        if not_skip == 0:
+            return
+
+    m_i = float("-inf")
+    l_i = 0.0
+    for start in range(0, N, BLOCK_N):
+        cols = start + tl.arange(0, BLOCK_N)
+        logits = tl.load(LOGITS + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
+        new_m_i = tl.maximum(m_i, tl.max(logits))
+        alpha = tl.exp(m_i - new_m_i)
+        l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
+        m_i = new_m_i
+    lse = m_i + tl.log(l_i)
+
+    ids = tl.load(INPUT_IDS)
+    x = tl.load(LOGITS + ids).to(tl.float32) / TEMPERATURE
+    logp = x - lse
+    tl.store(LOG_P, logp)
+
+
+# compute old_logp and ref_logp; this reduces peak memory by ~10 GB and does not require grad
+@torch.no_grad
+def fused_selective_log_softmax(logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 0.9, mask=None):
+    assert logits.is_contiguous()
+    B, L_ADD_1, N = logits.shape
+    L = L_ADD_1 - 1
+    input_ids = input_ids[:, -L:]
+    if mask is not None:
+        mask = mask[:, -L:]
+    log_p = torch.zeros(B, L, dtype=torch.float32, device=logits.device)
+    kwargs = {"BLOCK_N": 2048, "num_stages": 4, "num_warps": 1}
+    _selective_log_softmax_kernel[(B, L)](
+        logits, input_ids, log_p, mask, temperature, input_ids.stride(0), L, N, **kwargs
+    )
+    return log_p
+
+
+# @triton.autotune([triton.Config({"BLOCK_N":BLOCK_N}, num_stages=ns, num_warps=nw)
+#                   for BLOCK_N in [2048, 4096, 8192]
+#                   for ns in [1, 2, 4]
+#                   for nw in [1, 2, 4, 8, 16]],
+#                   key=['N'])
+@triton.jit
+def _grpo_loss_fwd_kernel(
+    LOGITS,
+    OLD_LOGP,
+    REF_LOGP,
+    INPUT_IDS,
+    COMPLETION_MASK,
+    ADVANTAGES,
+    LOSS,
+    LSE,
+    KL,
+    IS_CLIPPED,
+    TEMPERATURE,
+    BETA: tl.constexpr,
+    EPS_LOW,
+    EPS_HIGH,
+    L: tl.constexpr,
+    N: tl.constexpr,
+    BLOCK_N: tl.constexpr = 4096,
+):
+    off_b = tl.program_id(0).cast(tl.int64)
+    off_l = tl.program_id(1).cast(tl.int64)
+
+    if COMPLETION_MASK is not None:
+        COMPLETION_MASK += off_b * L + off_l
+        not_skip = tl.load(COMPLETION_MASK)
+        if not_skip == 0:
+            return
+
+    LOGITS += off_b * (L + 1) * N + off_l * N
+    INPUT_IDS += off_b * L + off_l
+    ADVANTAGES += off_b
+    LOSS += off_b * L + off_l
+    LSE += off_b * L + off_l
+    IS_CLIPPED += off_b * L + off_l
+
+    m_i = float("-inf")
+    l_i = 0.0
+    for start in range(0, N, BLOCK_N):
+        cols = start + tl.arange(0, BLOCK_N)
+        logits = tl.load(LOGITS + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
+        new_m_i = tl.maximum(m_i, tl.max(logits))
+        alpha = tl.exp(m_i - new_m_i)
+        l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
+        m_i = new_m_i
+    lse = m_i + tl.log(l_i)
+
+    idx = tl.load(INPUT_IDS)
+    x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE
+    logp = x - lse
+    if OLD_LOGP is None:
+        old_logp = logp
+    else:
+        OLD_LOGP += off_b * L + off_l
+        old_logp = tl.load(OLD_LOGP).to(tl.float32)
+    coef_1 = tl.exp(logp - old_logp)
+    coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
+    advantage = tl.load(ADVANTAGES).to(tl.float32)
+    per_token_loss1 = coef_1 * advantage
+    per_token_loss2 = coef_2 * advantage
+    per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
+    is_low_clipped = (coef_1 < 1 - EPS_LOW) & (advantage < 0)
+    is_high_clipped = (coef_1 > 1 + EPS_HIGH) & (advantage > 0)
+    is_clipped = is_low_clipped | is_high_clipped
+
+    if BETA != 0.0:
+        REF_LOGP += off_b * L + off_l
+        KL += off_b * L + off_l
+        ref_logp = tl.load(REF_LOGP).to(tl.float32)
+        kl = tl.exp(ref_logp - logp) - (ref_logp - logp) - 1
+        per_token_loss += BETA * kl
+        tl.store(KL, kl)
+
+    tl.store(LOSS, per_token_loss)
+    tl.store(LSE, lse)
+    tl.store(IS_CLIPPED, is_clipped)
+
+
+# @triton.autotune([triton.Config({"BLOCK_N":BLOCK_N}, num_stages=ns, num_warps=nw)
+#                   for BLOCK_N in [2048, 4096, 8192]
+#                   for ns in [1, 2, 4]
+#                   for nw in [1, 2, 4, 8, 16]],
+#                   key=['N'])
+@triton.jit
+def _grpo_loss_bwd_kernel(
+    DLOSS,
+    DLOGITS,
+    LOGITS,
+    OLD_LOGP,
+    REF_LOGP,
+    INPUT_IDS,
+    ADVANTAGES,
+    COMPLETION_MASK,
+    LSE,
+    TEMPERATURE,
+    BETA: tl.constexpr,
+    EPS_LOW,
+    EPS_HIGH,
+    loss_stride0,
+    loss_stride1,
+    L: tl.constexpr,
+    N: tl.constexpr,
+    BLOCK_N: tl.constexpr = 4096,
+):
+    off_b = tl.program_id(0).cast(tl.int64)
+    off_l = tl.program_id(1).cast(tl.int64)
+
+    DLOGITS += off_b * (L + 1) * N + off_l * N
+    if COMPLETION_MASK is not None:
+        COMPLETION_MASK += off_b * L + off_l
+        not_skip = tl.load(COMPLETION_MASK)
+        if not_skip == 0:
+            for start in range(0, N, BLOCK_N):
+                cols = tl.arange(0, BLOCK_N) + start
+                tl.store(DLOGITS + cols, 0.0, mask=cols < N)
+            return
+
+    LOGITS += off_b * (L + 1) * N + off_l * N
+    DLOSS += off_b * loss_stride0 + off_l * loss_stride1
+    INPUT_IDS += off_b * L + off_l
+    ADVANTAGES += off_b
+    LSE += off_b * L + off_l
+
+    dloss = tl.load(DLOSS).to(tl.float32)
+    lse = tl.load(LSE).to(tl.float32)
+
+    idx = tl.load(INPUT_IDS)
+    x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE
+    logp = x - lse
+    if OLD_LOGP is None:
+        old_logp = logp
+    else:
+        OLD_LOGP += off_b * L + off_l
+        old_logp = tl.load(OLD_LOGP).to(tl.float32)
+    coef_1 = tl.exp(logp - old_logp)
+    coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
+    advantage = tl.load(ADVANTAGES).to(tl.float32)
+    per_token_loss1 = coef_1 * advantage
+    per_token_loss2 = coef_2 * advantage
+    mask = per_token_loss2 >= per_token_loss1
+
+    dlogp = -per_token_loss1 * mask
+    if BETA != 0.0:
+        REF_LOGP += off_b * L + off_l
+        ref_logp = tl.load(REF_LOGP).to(tl.float32)
+        dlogp += BETA * (1 - tl.exp(ref_logp - logp))
+
+    dlogp = dlogp * dloss / TEMPERATURE
+    tl.debug_barrier()
+    for start_n in tl.range(0, N, BLOCK_N):
+        cols = start_n + tl.arange(0, BLOCK_N)
+        logits = tl.load(LOGITS + cols, mask=cols < N, other=-float("inf")).to(tl.float32) / TEMPERATURE
+        probs = tl.exp(logits - lse)
+        dlogits = tl.where(cols == idx, 1 - probs, -probs) * dlogp
+        tl.store(DLOGITS + cols, dlogits, mask=cols < N)
+
+
+class GrpoLossFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        logits,
+        old_logp,
+        ref_logp,
+        completion_ids,
+        advantages,
+        completion_mask,
+        temperature,
+        beta,
+        eps_low,
+        eps_high,
+        inplace,
+    ):
+        assert logits.is_contiguous() and completion_ids.is_contiguous()
+        assert old_logp is None or old_logp.is_contiguous()
+        assert (ref_logp is not None and ref_logp.is_contiguous()) if beta != 0.0 else True
+
+        B, L_ADD_1, N = logits.shape
+        L = L_ADD_1 - 1
+
+        if completion_mask is not None:
+            assert completion_mask.is_contiguous()
+
+        loss = torch.zeros(B, L, device=logits.device, dtype=torch.float32)
+        lse = torch.zeros_like(loss)
+        is_clipped = torch.zeros_like(loss)
+        kl = torch.zeros_like(loss) if beta != 0.0 else None
+        kwargs = {"BLOCK_N": 2048, "num_stages": 2, "num_warps": 1}
+        _grpo_loss_fwd_kernel[(B, L)](
+            logits,
+            old_logp,
+            ref_logp,
+            completion_ids,
+            completion_mask,
+            advantages,
+            loss,
+            lse,
+            kl,
+            is_clipped,
+            temperature,
+            beta,
+            eps_low,
+            eps_high,
+            L,
+            N,
+            **kwargs,
+        )
+        ctx.save_for_backward(logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse)
+        ctx.infos = (temperature, beta, eps_low, eps_high, inplace)
+        # return loss
+        return loss, kl, is_clipped
+
+    @staticmethod
+    def backward(ctx, *args):
+        dloss = args[0]
+        # print(dloss.shape)
+        logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse = ctx.saved_tensors
+        temperature, beta, eps_low, eps_high, inplace = ctx.infos
+        B, L_ADD_1, N = logits.shape
+        L = L_ADD_1 - 1
+        dlogits = logits.data if inplace else torch.empty_like(logits)
+        kwargs = {"BLOCK_N": 4096, "num_stages": 1, "num_warps": 16}
+        _grpo_loss_bwd_kernel[(B, L)](
+            dloss,
+            dlogits,
+            logits,
+            old_logp,
+            ref_logp,
+            completion_ids,
+            advantages,
+            completion_mask,
+            lse,
+            temperature,
+            beta,
+            eps_low,
+            eps_high,
+            *dloss.stride(),
+            L,
+            N,
+            **kwargs,
+        )
+        dlogits[:, -1, :] = 0
+        return dlogits, None, None, None, None, None, None, None, None, None, None
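
For orientation, here is a minimal usage sketch of the GRPO pieces above. It is not part of the diff: the import path `liger_kernel.ops.grpo_loss` is assumed from the file list, and the shapes, device, and hyperparameters are illustrative. In real training, `old_logp` and `ref_logp` would come from the behavior and reference policies rather than from the current logits.

```python
import torch

from liger_kernel.ops.grpo_loss import GrpoLossFunction, fused_selective_log_softmax

B, L, N = 2, 16, 32000  # (num completions, completion length, vocab size) -- illustrative
device = "cuda"

# Logits cover the prompt's last position plus the L completion positions: (B, L + 1, N).
logits = torch.randn(B, L + 1, N, device=device, requires_grad=True)
completion_ids = torch.randint(0, N, (B, L), device=device)
completion_mask = torch.ones(B, L, dtype=torch.int32, device=device)
advantages = torch.randn(B, device=device)

# Old-policy and reference per-token log-probs, computed without autograd to keep peak memory low.
# They are derived from the same logits here purely for illustration.
old_logp = fused_selective_log_softmax(logits.detach(), completion_ids, temperature=1.0, mask=completion_mask)
ref_logp = fused_selective_log_softmax(logits.detach(), completion_ids, temperature=1.0, mask=completion_mask)

per_token_loss, kl, is_clipped = GrpoLossFunction.apply(
    logits, old_logp, ref_logp, completion_ids, advantages, completion_mask,
    1.0,    # temperature
    0.04,   # beta: weight of the KL term (0.0 disables it, and ref_logp may then be None)
    0.2,    # eps_low
    0.2,    # eps_high
    False,  # inplace: reuse the logits storage for the gradient
)
loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
loss.backward()
```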
liger_kernel/ops/jsd.py
@@ -0,0 +1,201 @@
+from typing import Optional
+
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.utils import infer_device
+
+
+@triton.jit
+def _jsd_kernel(
+    X_ptr,  # input in logspace, X = log Q
+    X_stride,
+    Y_ptr,  # ground truth in logspace, Y = log P
+    Y_stride,
+    loss_ptr,
+    loss_stride,
+    dX_ptr,
+    dX_stride,
+    label_ptr,
+    beta: tl.constexpr,
+    n_non_ignore: int,
+    ignore_index: tl.constexpr,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+    HAS_LABEL: tl.constexpr,
+):
+    # JSD(P || Q) = (KL(P || M) + KL(Q || M)) / 2, M = (1/2) * (P + Q) = (1/2) * (e ^ Y + e ^ X)
+    #             = sum(P * log P + Q * log Q - 2 * M * log M) / 2
+    #             = sum(e ^ Y * Y + e ^ X * X - 2 * M * log M) / 2
+    # grad_x_i = 0.5 * Q * (X - log_M)
+    pid = tl.program_id(0).to(tl.int64)
+    X_ptr += pid * X_stride
+    dX_ptr += pid * dX_stride
+    Y_ptr += pid * Y_stride
+    loss_ptr += pid * loss_stride
+    label_ptr += pid
+
+    if HAS_LABEL:
+        label = tl.load(label_ptr)
+        if label == ignore_index:
+            for i in range(0, n_cols, BLOCK_SIZE):
+                offsets = i + tl.arange(0, BLOCK_SIZE)
+                tl.store(dX_ptr + offsets, 0.0, mask=offsets < n_cols)
+            return
+
+    for i in range(0, n_cols, BLOCK_SIZE):
+        offsets = i + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_cols
+        X = tl.load(X_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
+        Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
+
+        if beta == 0.0:  # forward KL
+            Y_max = tl.max(Y, axis=0)
+            Y_shifted = Y - Y_max
+            Y_prob = tl.exp(Y_shifted) * tl.exp(Y_max)  # Compensate for the shift
+            loss = Y_prob * (Y - X)
+            dX = -Y_prob
+        elif beta == 1.0:  # reverse KL
+            X_max = tl.max(X, axis=0)
+            X_shifted = X - X_max
+            X_prob = tl.exp(X_shifted) * tl.exp(X_max)  # Compensate for the shift
+            loss = X_prob * (X - Y)
+            dX = loss + X_prob
+        else:
+            max_val = tl.maximum(tl.max(X, axis=0), tl.max(Y, axis=0))
+            X_shifted = X - max_val
+            Y_shifted = Y - max_val
+
+            # Pre-compute exp(max_val) since it's used twice
+            exp_max = tl.exp(max_val)
+
+            # Compute exp terms with compensation
+            Q = tl.exp(X_shifted) * exp_max  # = exp(X)
+            P = tl.exp(Y_shifted) * exp_max  # = exp(Y)
+
+            # Pre-compute common terms
+            beta_P = beta * P
+            one_minus_beta_Q = (1 - beta) * Q
+            M = beta_P + one_minus_beta_Q
+            log_M = tl.log(M)  # No need to compensate as M is already in original scale
+
+            loss = beta_P * Y + one_minus_beta_Q * X - M * log_M
+            dX = one_minus_beta_Q * (X - log_M)
+
+        # Pre-compute scaling factor
+        scale = 1.0 / n_non_ignore
+        loss = loss * scale
+        dX = dX * scale
+
+        tl.store(loss_ptr + offsets, loss, mask=mask)
+        tl.store(dX_ptr + offsets, dX, mask=mask)
+
+
+MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536
+
+
+def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
+    BT, V = _input.shape
+    n_rows = BT
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+    # non reduction loss
+    loss = torch.zeros(_input.shape, dtype=torch.float32, device=_input.device)
+    dX = torch.empty_like(_input)
+
+    if has_label:
+        n_non_ignore = (shift_labels != ignore_index).sum().item()
+    else:
+        n_non_ignore = BT
+
+    _jsd_kernel[(n_rows,)](
+        X_ptr=_input,  # input in logspace, X = log Q
+        X_stride=_input.stride(-2),
+        Y_ptr=target,  # ground truth in logspace, Y = log P
+        Y_stride=target.stride(-2),
+        loss_ptr=loss,
+        loss_stride=loss.stride(-2),
+        dX_ptr=dX,
+        dX_stride=dX.stride(-2),
+        label_ptr=(shift_labels if has_label else torch.empty(1, device=_input.device)),  # dummy ptr if no label
+        beta=beta,
+        n_non_ignore=n_non_ignore,
+        ignore_index=ignore_index,
+        n_cols=V,
+        BLOCK_SIZE=BLOCK_SIZE,
+        HAS_LABEL=has_label,
+    )
+
+    loss = torch.sum(loss)
+    return loss.to(_input.dtype), dX
+
+
+def jsd_backward(dX, grad_output):
+    # If jsd is the last layer, grad_output is 1.0. Skip the mul to save time
+    if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
+        return dX
+    else:
+        return grad_output * dX
+
+
+class LigerJSDFunction(torch.autograd.Function):
+    r"""
+    This class implements the forward and backward pass for the generalized Jensen-Shannon Divergence.
+    .. math::
+        JSD(\beta)(P || Q)
+            = \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q))
+
+    .. note::
+        As with all the other losses in PyTorch, this function expects the first argument,
+        :attr:`_input`, to be the predictions, the output of the student model, in log-space
+        and the second, :attr:`target`, to be the observations, the output of the teacher model, in log-space.
+        This differs from the standard mathematical notation :math:`JSD(P || Q)` where
+        :math:`P` denotes the teacher model and :math:`Q` denotes the student model.
+    """
+
+    @staticmethod
+    @ensure_contiguous
+    def forward(
+        ctx,
+        _input: torch.Tensor,
+        target: torch.Tensor,
+        shift_labels: Optional[torch.Tensor] = None,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+    ) -> torch.Tensor:
+        """
+        Args:
+            _input (torch.Tensor): predict values with shape (BT, V) in logspace
+            target (torch.Tensor): ground truth values with shape (BT, V) in logspace
+            shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1].
+            beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
+            ignore_index (int): the index to ignore. Default: -100
+
+        Returns:
+            loss (torch.Tensor): generalized JSD
+        """
+        has_label = False
+        if shift_labels is not None:
+            assert shift_labels.shape == (_input.shape[0],), (
+                f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
+            )
+            shift_labels = shift_labels.contiguous()
+            has_label = True
+
+        loss, dX = jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label)
+        ctx.save_for_backward(dX)
+        return loss
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
+        (dX,) = ctx.saved_tensors
+        dX = jsd_backward(dX, grad_output)
+        return (
+            dX,
+            None,
+            None,
+            None,
+            None,
+        )
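
Likewise, a minimal sketch of driving `LigerJSDFunction` directly. It is not part of the diff: the import path is taken from the `liger_kernel/ops/jsd.py` entry in the file list, and the shapes are illustrative. Both arguments must already be log-probabilities, student first and teacher second, as the docstring above notes.

```python
import torch
import torch.nn.functional as F

from liger_kernel.ops.jsd import LigerJSDFunction

BT, V = 8, 32000  # (batch * sequence length, vocab size) -- illustrative
student_logits = torch.randn(BT, V, device="cuda", requires_grad=True)
teacher_logits = torch.randn(BT, V, device="cuda")

log_q = F.log_softmax(student_logits, dim=-1)  # _input: student distribution, log-space
log_p = F.log_softmax(teacher_logits, dim=-1)  # target: teacher distribution, log-space

shift_labels = torch.randint(0, V, (BT,), device="cuda")  # rows equal to ignore_index would be skipped
loss = LigerJSDFunction.apply(log_q, log_p, shift_labels, 0.5, -100)  # beta=0.5, ignore_index=-100
loss.backward()
```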