liger-kernel-nightly 0.5.5.dev20250402185702__py3-none-any.whl → 0.6.4.dev20260112233432__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of liger-kernel-nightly might be problematic.

Files changed (115)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +142 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +61 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +36 -0
  7. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  8. liger_kernel/chunked_loss/grpo_loss.py +76 -5
  9. liger_kernel/chunked_loss/jsd_loss.py +46 -15
  10. liger_kernel/ops/__init__.py +141 -0
  11. liger_kernel/ops/backends/README.md +151 -0
  12. liger_kernel/ops/backends/__init__.py +13 -0
  13. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  14. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
  15. liger_kernel/ops/backends/_ascend/ops/__init__.py +49 -0
  16. liger_kernel/ops/backends/_ascend/ops/geglu.py +266 -0
  17. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
  18. liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
  19. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  20. liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
  21. liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
  22. liger_kernel/ops/backends/registry.py +61 -0
  23. liger_kernel/ops/cross_entropy.py +134 -65
  24. liger_kernel/ops/dyt.py +115 -180
  25. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  26. liger_kernel/ops/fused_linear_cross_entropy.py +117 -23
  27. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  28. liger_kernel/ops/geglu.py +6 -4
  29. liger_kernel/ops/group_norm.py +7 -7
  30. liger_kernel/ops/grpo_loss.py +312 -0
  31. liger_kernel/ops/jsd.py +2 -1
  32. liger_kernel/ops/kl_div.py +9 -5
  33. liger_kernel/ops/layer_norm.py +146 -78
  34. liger_kernel/ops/llama4_rope.py +225 -0
  35. liger_kernel/ops/multi_token_attention.py +207 -0
  36. liger_kernel/ops/poly_norm.py +390 -0
  37. liger_kernel/ops/rms_norm.py +398 -99
  38. liger_kernel/ops/rope.py +1 -1
  39. liger_kernel/ops/softmax.py +201 -0
  40. liger_kernel/ops/sparsemax.py +179 -0
  41. liger_kernel/ops/swiglu.py +1 -1
  42. liger_kernel/ops/tiled_mlp.py +136 -0
  43. liger_kernel/ops/utils.py +14 -0
  44. liger_kernel/transformers/__init__.py +208 -17
  45. liger_kernel/transformers/auto_model.py +21 -0
  46. liger_kernel/transformers/cross_entropy.py +9 -4
  47. liger_kernel/transformers/dyt.py +6 -4
  48. liger_kernel/transformers/experimental/__init__.py +5 -0
  49. liger_kernel/transformers/experimental/embedding.py +1 -1
  50. liger_kernel/transformers/fsdp.py +55 -0
  51. liger_kernel/transformers/functional.py +122 -20
  52. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  53. liger_kernel/transformers/fused_linear_cross_entropy.py +16 -5
  54. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  55. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  56. liger_kernel/transformers/geglu.py +1 -1
  57. liger_kernel/transformers/group_norm.py +1 -1
  58. liger_kernel/transformers/grpo_loss.py +153 -0
  59. liger_kernel/transformers/jsd.py +1 -1
  60. liger_kernel/transformers/kl_div.py +1 -1
  61. liger_kernel/transformers/layer_norm.py +1 -1
  62. liger_kernel/transformers/llama4_rope.py +93 -0
  63. liger_kernel/transformers/model/exaone4.py +136 -0
  64. liger_kernel/transformers/model/falcon_h1.py +122 -0
  65. liger_kernel/transformers/model/gemma.py +57 -27
  66. liger_kernel/transformers/model/gemma2.py +65 -28
  67. liger_kernel/transformers/model/gemma3.py +331 -0
  68. liger_kernel/transformers/model/glm4.py +141 -0
  69. liger_kernel/transformers/model/glm4v.py +163 -0
  70. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  71. liger_kernel/transformers/model/gpt_oss.py +211 -0
  72. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  73. liger_kernel/transformers/model/internvl.py +157 -0
  74. liger_kernel/transformers/model/llama.py +109 -27
  75. liger_kernel/transformers/model/llama4.py +121 -0
  76. liger_kernel/transformers/model/llava.py +111 -136
  77. liger_kernel/transformers/model/loss_utils.py +50 -12
  78. liger_kernel/transformers/model/mistral.py +51 -34
  79. liger_kernel/transformers/model/mixtral.py +50 -29
  80. liger_kernel/transformers/model/mllama.py +46 -24
  81. liger_kernel/transformers/model/olmo2.py +47 -22
  82. liger_kernel/transformers/model/olmo3.py +142 -0
  83. liger_kernel/transformers/model/output_classes.py +147 -0
  84. liger_kernel/transformers/model/paligemma.py +50 -14
  85. liger_kernel/transformers/model/phi3.py +47 -172
  86. liger_kernel/transformers/model/qwen2.py +55 -23
  87. liger_kernel/transformers/model/qwen2_5_vl.py +62 -103
  88. liger_kernel/transformers/model/qwen2_vl.py +59 -108
  89. liger_kernel/transformers/model/qwen3.py +136 -0
  90. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  91. liger_kernel/transformers/model/qwen3_next.py +146 -0
  92. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  93. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  94. liger_kernel/transformers/model/smollm3.py +199 -0
  95. liger_kernel/transformers/model/smolvlm.py +158 -0
  96. liger_kernel/transformers/monkey_patch.py +2018 -244
  97. liger_kernel/transformers/multi_token_attention.py +64 -0
  98. liger_kernel/transformers/poly_norm.py +42 -0
  99. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  100. liger_kernel/transformers/rms_norm.py +54 -6
  101. liger_kernel/transformers/rope.py +45 -1
  102. liger_kernel/transformers/softmax.py +12 -0
  103. liger_kernel/transformers/sparsemax.py +16 -0
  104. liger_kernel/transformers/swiglu.py +39 -1
  105. liger_kernel/transformers/tiled_mlp.py +125 -0
  106. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  107. liger_kernel/transformers/tvd.py +1 -1
  108. liger_kernel/utils.py +63 -0
  109. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA +73 -39
  110. liger_kernel_nightly-0.6.4.dev20260112233432.dist-info/RECORD +132 -0
  111. liger_kernel_nightly-0.5.5.dev20250402185702.dist-info/RECORD +0 -80
  112. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/LICENSE +0 -0
  113. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/NOTICE +0 -0
  114. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/WHEEL +0 -0
  115. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/top_level.txt +0 -0
liger_kernel/ops/geglu.py CHANGED
@@ -7,8 +7,9 @@ import triton.language as tl
  from liger_kernel.ops.utils import calculate_settings
  from liger_kernel.ops.utils import compare_version
  from liger_kernel.ops.utils import ensure_contiguous
+ from liger_kernel.utils import is_npu_available

- if compare_version("triton", operator.ge, "3.0.0"):
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
      try:
          # typical import path with dispatch available
          from triton.language.extra.libdevice import tanh
@@ -40,7 +41,7 @@ def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE
      tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
      tanh_result = tanh(tanh_arg)
      geglu_a = 0.5 * a_row * (1 + tanh_result)
-     c_row = geglu_a * b_row
+     c_row = geglu_a.cast(b_row.dtype) * b_row
      tl.store(c + col_offsets, c_row, mask=mask)


@@ -66,8 +67,9 @@ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SI
      tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
      tanh_result = tanh(tanh_arg)
      geglu_a = 0.5 * a_row * (1 + tanh_result)
+     geglu_a = geglu_a.to(dc_row.dtype).to(tl.float32)

-     db_row = dc_row * geglu_a
+     db_row = dc_row.cast(tl.float32) * geglu_a

      # Gradient w.r.t. a can be computed with:
      # b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))
@@ -78,7 +80,7 @@ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SI
      da_row = dc_row * b_row * (term1 + term2)

      tl.store(a + col_offsets, da_row, mask=mask)
-     tl.store(b + col_offsets, db_row, mask=mask)
+     tl.store(b + col_offsets, db_row.to(dc_row.dtype), mask=mask)


  def geglu_forward(a, b):
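Note (not part of the diff): the new casts only control which dtype the intermediate GEGLU value is held in; the tanh-approximate formula itself is unchanged. A minimal PyTorch sketch with made-up inputs a and b, checking that the kernel's formula matches torch's approximate="tanh" GELU gated by b:

import math

import torch
import torch.nn.functional as F

a = torch.randn(4, 8)
b = torch.randn(4, 8)

# Same expression as the kernel: 0.5 * a * (1 + tanh(sqrt(2/pi) * (a + 0.044715 * a^3))) * b
sqrt_2_over_pi = math.sqrt(2.0 / math.pi)
geglu_manual = 0.5 * a * (1.0 + torch.tanh(sqrt_2_over_pi * (a + 0.044715 * a.pow(3)))) * b
geglu_reference = F.gelu(a, approximate="tanh") * b

assert torch.allclose(geglu_manual, geglu_reference, atol=1e-6)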
liger_kernel/ops/group_norm.py CHANGED
@@ -6,8 +6,9 @@ import triton.language as tl

  from liger_kernel.ops.utils import compare_version
  from liger_kernel.ops.utils import ensure_contiguous
+ from liger_kernel.utils import is_npu_available

- if compare_version("triton", operator.ge, "3.0.0"):
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
      try:
          # typical import path with dispatch available
          from triton.language.extra.libdevice import rsqrt
@@ -77,15 +78,14 @@ def _group_norm_forward_kernel(
      for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
          W = tl.load(W_ptr + channel_idx)
          B = tl.load(B_ptr + channel_idx)
-         for i in range(0, hidden_size_per_channel, BLOCK_SIZE):
+         # Calculate channel offset within the group
+         channel_offset = (channel_idx - group_idx * channels_per_group) * hidden_size_per_channel
+         for i in tl.range(0, hidden_size_per_channel, BLOCK_SIZE):
              hidden_size_offsets = i + block_range
              mask = hidden_size_offsets < hidden_size_per_channel
-             X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m)
+             X = tl.load(X_ptr + channel_offset + hidden_size_offsets, mask=mask, other=m)
              Y = (X - m) * rstd * W + B
-             tl.store(Y_ptr + hidden_size_offsets, Y, mask=mask)
-
-         X_ptr += hidden_size_per_channel
-         Y_ptr += hidden_size_per_channel
+             tl.store(Y_ptr + channel_offset + hidden_size_offsets, Y, mask=mask)

      tl.store(Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, m)
      tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd)
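Note (not part of the diff): the rewritten loop addresses each channel through an explicit channel_offset instead of advancing X_ptr/Y_ptr after every channel. A plain-Python sketch with hypothetical sizes, showing the flat offsets produced for one group:

channels_per_group = 2
hidden_size_per_channel = 4
group_idx = 1  # offsets are relative to the start of this group's block

for channel_idx in range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
    channel_offset = (channel_idx - group_idx * channels_per_group) * hidden_size_per_channel
    flat_indices = [channel_offset + i for i in range(hidden_size_per_channel)]
    print(channel_idx, flat_indices)
# channel 2 -> [0, 1, 2, 3], channel 3 -> [4, 5, 6, 7]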
liger_kernel/ops/grpo_loss.py ADDED
@@ -0,0 +1,312 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+
+ @triton.jit
+ def _selective_log_softmax_kernel(
+     LOGITS,
+     INPUT_IDS,
+     LOG_P,
+     MASK,
+     TEMPERATURE,
+     stride_input_ids_b,
+     L: tl.constexpr,
+     N: tl.constexpr,
+     BLOCK_N: tl.constexpr = 4096,
+ ):
+     off_b = tl.program_id(0).cast(tl.int64)
+     off_l = tl.program_id(1).cast(tl.int64)
+
+     LOGITS += off_b * (L + 1) * N + off_l * N
+     INPUT_IDS += off_b * stride_input_ids_b + off_l
+     LOG_P += off_b * L + off_l
+
+     if MASK is not None:
+         MASK += off_b * stride_input_ids_b + off_l
+         not_skip = tl.load(MASK)
+         if not_skip == 0:
+             return
+
+     m_i = float("-inf")
+     l_i = 0.0
+     for start in range(0, N, BLOCK_N):
+         cols = start + tl.arange(0, BLOCK_N)
+         logits = tl.load(LOGITS + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
+         new_m_i = tl.maximum(m_i, tl.max(logits))
+         alpha = tl.exp(m_i - new_m_i)
+         l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
+         m_i = new_m_i
+     lse = m_i + tl.log(l_i)
+
+     ids = tl.load(INPUT_IDS)
+     x = tl.load(LOGITS + ids).to(tl.float32) / TEMPERATURE
+     logp = x - lse
+     tl.store(LOG_P, logp)
+
+
+ # Compute old_logp and ref_logp; this reduces 10G of peak memory and does not require grad.
+ @torch.no_grad
+ def fused_selective_log_softmax(logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 0.9, mask=None):
+     assert logits.is_contiguous()
+     B, L_ADD_1, N = logits.shape
+     L = L_ADD_1 - 1
+     input_ids = input_ids[:, -L:]
+     if mask is not None:
+         mask = mask[:, -L:]
+     log_p = torch.zeros(B, L, dtype=torch.float32, device=logits.device)
+     kwargs = {"BLOCK_N": 2048, "num_stages": 4, "num_warps": 1}
+     _selective_log_softmax_kernel[(B, L)](
+         logits, input_ids, log_p, mask, temperature, input_ids.stride(0), L, N, **kwargs
+     )
+     return log_p
+
+
+ # @triton.autotune([triton.Config({"BLOCK_N":BLOCK_N}, num_stages=ns, num_warps=nw)
+ # for BLOCK_N in [2048, 4096, 8192]
+ # for ns in [1, 2, 4]
+ # for nw in [1, 2, 4, 8, 16]],
+ # key=['N'])
+ @triton.jit
+ def _grpo_loss_fwd_kernel(
+     LOGITS,
+     OLD_LOGP,
+     REF_LOGP,
+     INPUT_IDS,
+     COMPLETION_MASK,
+     ADVANTAGES,
+     LOSS,
+     LSE,
+     KL,
+     IS_CLIPPED,
+     TEMPERATURE,
+     BETA: tl.constexpr,
+     EPS_LOW,
+     EPS_HIGH,
+     L: tl.constexpr,
+     N: tl.constexpr,
+     BLOCK_N: tl.constexpr = 4096,
+ ):
+     off_b = tl.program_id(0).cast(tl.int64)
+     off_l = tl.program_id(1).cast(tl.int64)
+
+     if COMPLETION_MASK is not None:
+         COMPLETION_MASK += off_b * L + off_l
+         not_skip = tl.load(COMPLETION_MASK)
+         if not_skip == 0:
+             return
+
+     LOGITS += off_b * (L + 1) * N + off_l * N
+     INPUT_IDS += off_b * L + off_l
+     ADVANTAGES += off_b
+     LOSS += off_b * L + off_l
+     LSE += off_b * L + off_l
+     IS_CLIPPED += off_b * L + off_l
+
+     m_i = float("-inf")
+     l_i = 0.0
+     for start in range(0, N, BLOCK_N):
+         cols = start + tl.arange(0, BLOCK_N)
+         logits = tl.load(LOGITS + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
+         new_m_i = tl.maximum(m_i, tl.max(logits))
+         alpha = tl.exp(m_i - new_m_i)
+         l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
+         m_i = new_m_i
+     lse = m_i + tl.log(l_i)
+
+     idx = tl.load(INPUT_IDS)
+     x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE
+     logp = x - lse
+     if OLD_LOGP is None:
+         old_logp = logp
+     else:
+         OLD_LOGP += off_b * L + off_l
+         old_logp = tl.load(OLD_LOGP).to(tl.float32)
+     coef_1 = tl.exp(logp - old_logp)
+     coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
+     advantage = tl.load(ADVANTAGES).to(tl.float32)
+     per_token_loss1 = coef_1 * advantage
+     per_token_loss2 = coef_2 * advantage
+     per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
+     is_low_clipped = (coef_1 < 1 - EPS_LOW) & (advantage < 0)
+     is_high_clipped = (coef_1 > 1 + EPS_HIGH) & (advantage > 0)
+     is_clipped = is_low_clipped | is_high_clipped
+
+     if BETA != 0.0:
+         REF_LOGP += off_b * L + off_l
+         KL += off_b * L + off_l
+         ref_logp = tl.load(REF_LOGP).to(tl.float32)
+         kl = tl.exp(ref_logp - logp) - (ref_logp - logp) - 1
+         per_token_loss += BETA * kl
+         tl.store(KL, kl)
+
+     tl.store(LOSS, per_token_loss)
+     tl.store(LSE, lse)
+     tl.store(IS_CLIPPED, is_clipped)
+
+
+ # @triton.autotune([triton.Config({"BLOCK_N":BLOCK_N}, num_stages=ns, num_warps=nw)
+ # for BLOCK_N in [2048, 4096, 8192]
+ # for ns in [1, 2, 4]
+ # for nw in [1, 2, 4, 8, 16]],
+ # key=['N'])
+ @triton.jit
+ def _grpo_loss_bwd_kernel(
+     DLOSS,
+     DLOGITS,
+     LOGITS,
+     OLD_LOGP,
+     REF_LOGP,
+     INPUT_IDS,
+     ADVANTAGES,
+     COMPLETION_MASK,
+     LSE,
+     TEMPERATURE,
+     BETA: tl.constexpr,
+     EPS_LOW,
+     EPS_HIGH,
+     loss_stride0,
+     loss_stride1,
+     L: tl.constexpr,
+     N: tl.constexpr,
+     BLOCK_N: tl.constexpr = 4096,
+ ):
+     off_b = tl.program_id(0).cast(tl.int64)
+     off_l = tl.program_id(1).cast(tl.int64)
+
+     DLOGITS += off_b * (L + 1) * N + off_l * N
+     if COMPLETION_MASK is not None:
+         COMPLETION_MASK += off_b * L + off_l
+         not_skip = tl.load(COMPLETION_MASK)
+         if not_skip == 0:
+             for start in range(0, N, BLOCK_N):
+                 cols = tl.arange(0, BLOCK_N) + start
+                 tl.store(DLOGITS + cols, 0.0, mask=cols < N)
+             return
+
+     LOGITS += off_b * (L + 1) * N + off_l * N
+     DLOSS += off_b * loss_stride0 + off_l * loss_stride1
+     INPUT_IDS += off_b * L + off_l
+     ADVANTAGES += off_b
+     LSE += off_b * L + off_l
+
+     dloss = tl.load(DLOSS).to(tl.float32)
+     lse = tl.load(LSE).to(tl.float32)
+
+     idx = tl.load(INPUT_IDS)
+     x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE
+     logp = x - lse
+     if OLD_LOGP is None:
+         old_logp = logp
+     else:
+         OLD_LOGP += off_b * L + off_l
+         old_logp = tl.load(OLD_LOGP).to(tl.float32)
+     coef_1 = tl.exp(logp - old_logp)
+     coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
+     advantage = tl.load(ADVANTAGES).to(tl.float32)
+     per_token_loss1 = coef_1 * advantage
+     per_token_loss2 = coef_2 * advantage
+     mask = per_token_loss2 >= per_token_loss1
+
+     dlogp = -per_token_loss1 * mask
+     if BETA != 0.0:
+         REF_LOGP += off_b * L + off_l
+         ref_logp = tl.load(REF_LOGP).to(tl.float32)
+         dlogp += BETA * (1 - tl.exp(ref_logp - logp))
+
+     dlogp = dlogp * dloss / TEMPERATURE
+     tl.debug_barrier()
+     for start_n in tl.range(0, N, BLOCK_N):
+         cols = start_n + tl.arange(0, BLOCK_N)
+         logits = tl.load(LOGITS + cols, mask=cols < N, other=-float("inf")).to(tl.float32) / TEMPERATURE
+         probs = tl.exp(logits - lse)
+         dlogits = tl.where(cols == idx, 1 - probs, -probs) * dlogp
+         tl.store(DLOGITS + cols, dlogits, mask=cols < N)
+
+
+ class GrpoLossFunction(torch.autograd.Function):
+     @staticmethod
+     def forward(
+         ctx,
+         logits,
+         old_logp,
+         ref_logp,
+         completion_ids,
+         advantages,
+         completion_mask,
+         temperature,
+         beta,
+         eps_low,
+         eps_high,
+         inplace,
+     ):
+         assert logits.is_contiguous() and completion_ids.is_contiguous()
+         assert old_logp is None or old_logp.is_contiguous()
+         assert (ref_logp is not None and ref_logp.is_contiguous()) if beta != 0.0 else True
+
+         B, L_ADD_1, N = logits.shape
+         L = L_ADD_1 - 1
+
+         if completion_mask is not None:
+             assert completion_mask.is_contiguous()
+
+         loss = torch.zeros(B, L, device=logits.device, dtype=torch.float32)
+         lse = torch.zeros_like(loss)
+         is_clipped = torch.zeros_like(loss)
+         kl = torch.zeros_like(loss) if beta != 0.0 else None
+         kwargs = {"BLOCK_N": 2048, "num_stages": 2, "num_warps": 1}
+         _grpo_loss_fwd_kernel[(B, L)](
+             logits,
+             old_logp,
+             ref_logp,
+             completion_ids,
+             completion_mask,
+             advantages,
+             loss,
+             lse,
+             kl,
+             is_clipped,
+             temperature,
+             beta,
+             eps_low,
+             eps_high,
+             L,
+             N,
+             **kwargs,
+         )
+         ctx.save_for_backward(logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse)
+         ctx.infos = (temperature, beta, eps_low, eps_high, inplace)
+         # return loss
+         return loss, kl, is_clipped
+
+     @staticmethod
+     def backward(ctx, *args):
+         dloss = args[0]
+         # print(dloss.shape)
+         logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse = ctx.saved_tensors
+         temperature, beta, eps_low, eps_high, inplace = ctx.infos
+         B, L_ADD_1, N = logits.shape
+         L = L_ADD_1 - 1
+         dlogits = logits.data if inplace else torch.empty_like(logits)
+         kwargs = {"BLOCK_N": 4096, "num_stages": 1, "num_warps": 16}
+         _grpo_loss_bwd_kernel[(B, L)](
+             dloss,
+             dlogits,
+             logits,
+             old_logp,
+             ref_logp,
+             completion_ids,
+             advantages,
+             completion_mask,
+             lse,
+             temperature,
+             beta,
+             eps_low,
+             eps_high,
+             *dloss.stride(),
+             L,
+             N,
+             **kwargs,
+         )
+         dlogits[:, -1, :] = 0
+         return dlogits, None, None, None, None, None, None, None, None, None, None
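Note (not part of the diff): a possible usage sketch for the new autograd Function, assuming a CUDA device and made-up shapes and hyperparameters; a higher-level wrapper is also added in liger_kernel/transformers/grpo_loss.py.

import torch

from liger_kernel.ops.grpo_loss import GrpoLossFunction, fused_selective_log_softmax

B, L, N = 2, 16, 32000  # batch size, completion length, vocab size (illustrative)
device = "cuda"

logits = torch.randn(B, L + 1, N, device=device, requires_grad=True)  # must be contiguous
completion_ids = torch.randint(0, N, (B, L), device=device)
advantages = torch.randn(B, device=device)
completion_mask = torch.ones(B, L, dtype=torch.int32, device=device)

# Reference log-probs can be precomputed without grad via the fused selective log-softmax.
ref_logp = fused_selective_log_softmax(logits.detach(), completion_ids, temperature=1.0, mask=completion_mask)

loss, kl, is_clipped = GrpoLossFunction.apply(
    logits,
    None,             # old_logp (None falls back to the current log-probs)
    ref_logp,         # required because beta != 0
    completion_ids,
    advantages,
    completion_mask,
    1.0,              # temperature
    0.04,             # beta (KL coefficient)
    0.2,              # eps_low
    0.2,              # eps_high
    False,            # inplace
)
loss.mean().backward()  # gradients flow only into logits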
liger_kernel/ops/jsd.py CHANGED
@@ -5,6 +5,7 @@ import triton
  import triton.language as tl

  from liger_kernel.ops.utils import ensure_contiguous
+ from liger_kernel.utils import infer_device


  @triton.jit
@@ -92,7 +93,7 @@ def _jsd_kernel(
      tl.store(dX_ptr + offsets, dX, mask=mask)


- MAX_FUSED_SIZE = 65536
+ MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536


  def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
liger_kernel/ops/kl_div.py CHANGED
@@ -6,6 +6,7 @@ import triton.language as tl

  from liger_kernel.ops.utils import ensure_contiguous
  from liger_kernel.ops.utils import is_hip
+ from liger_kernel.utils import infer_device


  def get_num_warps(BLOCK_SIZE):
@@ -20,7 +21,12 @@ def get_num_warps(BLOCK_SIZE):
      return num_warps


- MAX_FUSED_SIZE = 65536 // 4 # 65536 // 4 or 8 works the best
+ if infer_device() == "xpu":
+     MAX_FUSED_SIZE = 8192
+ elif infer_device() == "npu":
+     MAX_FUSED_SIZE = 8192
+ else:
+     MAX_FUSED_SIZE = 65536 // 4 # 65536 // 4 or 8 works the best

  REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]

@@ -115,9 +121,8 @@ def _kldiv_kernel_backward(

  def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
      BT, V = y_pred.shape
-
      BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
-     num_warps = get_num_warps(BLOCK_SIZE)
+     num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)

      grid = (BT,)
      reduction = _str_to_reduction_mode[reduction]
@@ -155,9 +160,8 @@ def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]

  def kldiv_backward_triton(target, grad_output, new_grads, log_target):
      BT, V = target.shape
-
      BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
-     num_warps = get_num_warps(BLOCK_SIZE)
+     num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)

      grid = (BT,)
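Note (not part of the diff): the device-dependent MAX_FUSED_SIZE only changes the upper bound on the block size these helpers pick. A small sketch with an illustrative vocabulary size:

import triton

V = 32000
next_pow2 = triton.next_power_of_2(V)  # 32768

for device, max_fused in [("cuda/hip", 65536 // 4), ("xpu", 8192), ("npu", 8192)]:
    print(device, min(max_fused, next_pow2))
# cuda/hip -> 16384, xpu -> 8192, npu -> 8192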