liger-kernel-nightly 0.5.6.dev20250407214804__py3-none-any.whl → 0.5.6.dev20250408182156__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
liger_kernel/ops/layer_norm.py

@@ -154,6 +154,11 @@ def layer_norm_forward(X, W, B, eps):
             f"must match weight size (W.shape[0]={W.shape[0]})"
         )
 
+    # XPU-specific optimization
+    kernel_args = {}
+    if X.device.type == "xpu":
+        kernel_args["grf_mode"] = "large"
+
     _layer_norm_forward_kernel[(n_rows,)](
         Y,
         Y.stride(0),
@@ -171,6 +176,7 @@ def layer_norm_forward(X, W, B, eps):
         eps,
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=num_warps,
+        **kernel_args,  # XPU-specific optimization
     )
     return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps
 
@@ -185,7 +191,7 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
     if X.device.type == "cuda":
         sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
     elif X.device.type == "xpu":
-        sm_count = torch.xpu.get_device_properties(X.device).gpu_subslice_count
+        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
 
     DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
     _DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
@@ -208,6 +214,12 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
         if X.dtype == torch.float16
         else tl.float32  # fallback to float32 for other types
     )
+
+    # XPU-specific optimization
+    kernel_args = {}
+    if X.device.type == "xpu":
+        kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
+
     _layer_norm_backward_kernel[grid](
         X,
         W,
@@ -227,6 +239,7 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
         rows_per_program,
         BLOCK_SIZE=BLOCK_SIZE,
         dtype=triton_dtype,
+        **kernel_args,  # XPU-specific optimization
     )
 
     DW = _DW.sum(dim=0).to(W.dtype)
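Both layer-norm launches above follow one pattern: build a dict of extra launch options only when the tensors live on an Intel XPU, then unpack it into the Triton grid launch so the CUDA path is untouched (`grf_mode="large"` is an option of the Intel XPU Triton backend that requests the large register file). A minimal sketch of that pattern, using a toy kernel rather than the library's `_layer_norm_*_kernel`:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _toy_row_kernel(X, Y, n_cols, BLOCK_SIZE: tl.constexpr):
    # Toy row-wise kernel standing in for the real _layer_norm_*_kernel.
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols
    x = tl.load(X + row * n_cols + cols, mask=mask, other=0.0)
    tl.store(Y + row * n_cols + cols, x * 2.0, mask=mask)


def launch_with_device_args(X: torch.Tensor) -> torch.Tensor:
    X = X.contiguous()
    n_rows, n_cols = X.shape
    Y = torch.empty_like(X)
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Extra launch options only for Intel XPUs; the dict stays empty elsewhere,
    # so the **kernel_args splat is a no-op on CUDA.
    kernel_args = {}
    if X.device.type == "xpu":
        kernel_args["grf_mode"] = "large"  # XPU backend option: use the large register file

    _toy_row_kernel[(n_rows,)](X, Y, n_cols, BLOCK_SIZE=BLOCK_SIZE, **kernel_args)
    return Y
```

Because `kernel_args` is only populated on XPU devices, CUDA launches see exactly the same keyword arguments as before this change.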
liger_kernel/ops/rms_norm.py

@@ -223,6 +223,10 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
     # Check constraints.
     assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
 
+    # XPU-specific optimization
+    kernel_args = {}
+    if X.device.type == "xpu":
+        kernel_args["grf_mode"] = "large"
     _rms_norm_forward_kernel[(n_rows,)](
         Y,
         Y.stride(0),
@@ -238,6 +242,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
         casting_mode,
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=num_warps,
+        **kernel_args,  # XPU-specific optimization
     )
     return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode
 
@@ -252,7 +257,7 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
     if X.device.type == "cuda":
         sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
     elif X.device.type == "xpu":
-        sm_count = torch.xpu.get_device_properties(X.device).gpu_subslice_count
+        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
 
     # fp32 for numerical stability especially.
     _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
@@ -267,6 +272,11 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
     else:
         dX = torch.zeros_like(dY)
 
+    # XPU-specific optimization
+    kernel_args = {}
+    if X.device.type == "xpu":
+        kernel_args["grf_mode"] = "large"
+
     _rms_norm_backward_kernel[grid](
         dY,
         dY.stride(0),
@@ -288,6 +298,7 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
         casting_mode,
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=num_warps,
+        **kernel_args,  # XPU-specific optimization
     )
     dX = dX.view(*shape)
     dW = _dW.sum(dim=0).to(W.dtype)
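The other change in both files is the switch from `gpu_subslice_count` to `gpu_eu_count` when counting parallel programs on XPU. Both backward passes allocate one partial weight-gradient row per program (`_dW` of shape `(sm_count, n_cols)`) and reduce them at the end, so this property decides how many partial rows exist and how many input rows each program covers. A hedged, Triton-free sketch of that buffer-sizing idea; the gradient math inside the loop is a placeholder, not the library's actual `dW` computation:

```python
import torch


def count_parallel_programs(device: torch.device) -> int:
    # Mirrors the device-property query in the diff; values depend on the hardware.
    if device.type == "cuda":
        return torch.cuda.get_device_properties(device).multi_processor_count
    if device.type == "xpu":
        # After this change: execution-unit count instead of subslice count.
        return torch.xpu.get_device_properties(device).gpu_eu_count
    return 1  # CPU fallback, for illustration only


def reduce_partial_dw(dY: torch.Tensor, X: torch.Tensor) -> torch.Tensor:
    """Toy stand-in: each 'program' accumulates a slice of rows, then partials are summed."""
    n_rows, n_cols = X.shape
    sm_count = count_parallel_programs(X.device)
    _dW = torch.zeros((sm_count, n_cols), dtype=torch.float32, device=X.device)
    rows_per_program = -(-n_rows // sm_count)  # ceil division
    for pid in range(min(sm_count, n_rows)):
        rows = slice(pid * rows_per_program, min((pid + 1) * rows_per_program, n_rows))
        _dW[pid] = (dY[rows] * X[rows]).sum(dim=0)  # placeholder for the real dW math
    return _dW.sum(dim=0)
```

A larger `sm_count` means more (smaller) partial rows accumulated in fp32 before the final reduction.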
liger_kernel_nightly-0.5.6.dev20250407214804.dist-info/METADATA → liger_kernel_nightly-0.5.6.dev20250408182156.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel_nightly
-Version: 0.5.6.dev20250407214804
+Version: 0.5.6.dev20250408182156
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
 Copyright 2024 LinkedIn Corporation
liger_kernel_nightly-0.5.6.dev20250407214804.dist-info/RECORD → liger_kernel_nightly-0.5.6.dev20250408182156.dist-info/RECORD

@@ -24,9 +24,9 @@ liger_kernel/ops/geglu.py,sha256=axGvCIvlBzuluoAIrWTsp2iZM4BFKNInkPov8YVvH9E,412
 liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2woggk,10837
 liger_kernel/ops/jsd.py,sha256=onHp5T3MbvJaVz5Vup7Ww6EQp_HTaZeayTjJk6FgQMY,7042
 liger_kernel/ops/kl_div.py,sha256=NkG7D6_DnPBzr-ohhYiQbRBnq_fbGmpn5UU7y0UBKQo,8420
-liger_kernel/ops/layer_norm.py,sha256=6roQjioyg-9O2qLPV8nL4U0-5UH80tdzOMTWwjvDnn8,7961
+liger_kernel/ops/layer_norm.py,sha256=vWCyOm-F2GMAilB-ozJcFeUQQLCJoTE_uiXq-_0uYuI,8356
 liger_kernel/ops/qwen2vl_mrope.py,sha256=3GExhYpLgB4VUtyZyjRk8XjEur3W4EWF6HQ67ML5vBU,8481
-liger_kernel/ops/rms_norm.py,sha256=PWLJcdIKU5e-8BuYFHd9Cqlq6wmr6fUXKi9zQD4LetU,11727
+liger_kernel/ops/rms_norm.py,sha256=PP27OIBmV9By63i13jot9ylDowW0nuxY_JFIkaPLgL4,12078
 liger_kernel/ops/rope.py,sha256=ofmBOkUpZZO-Q8Z5B_LOFYYLD-YT-8WnJ4vGOrDYouI,8943
 liger_kernel/ops/swiglu.py,sha256=KmgMjaJQnbLLgZn2nEpbwHU_xpnYRweCyrLQSVvM1vA,3015
 liger_kernel/ops/tvd.py,sha256=FHJtLQI95ijqgg9UtaHpMAjSCiPxB6CduPwPMcGxelc,6405
@@ -74,9 +74,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.5.6.dev20250407214804.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
-liger_kernel_nightly-0.5.6.dev20250407214804.dist-info/METADATA,sha256=0lQVqhPNaqGVZvOrb6MxIp2eP7IYoABa4llfB8Ua868,23297
-liger_kernel_nightly-0.5.6.dev20250407214804.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
-liger_kernel_nightly-0.5.6.dev20250407214804.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
-liger_kernel_nightly-0.5.6.dev20250407214804.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
-liger_kernel_nightly-0.5.6.dev20250407214804.dist-info/RECORD,,
+liger_kernel_nightly-0.5.6.dev20250408182156.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.6.dev20250408182156.dist-info/METADATA,sha256=SP0FXayK2-JFayGwAcDBEbRk3PGmGqZVGCZw_PBG3jg,23297
+liger_kernel_nightly-0.5.6.dev20250408182156.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.6.dev20250408182156.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+liger_kernel_nightly-0.5.6.dev20250408182156.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.6.dev20250408182156.dist-info/RECORD,,
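Only the two edited modules change hashes in RECORD; the size fields show layer_norm.py growing from 7961 to 8356 bytes and rms_norm.py from 11727 to 12078. For reference, each RECORD line is `path,sha256=<digest>,<size>`, where the digest is the urlsafe base64 encoding of the file's SHA-256 hash with padding stripped. A small sketch that recomputes such an entry (the file path is illustrative):

```python
import base64
import hashlib
from pathlib import Path


def record_entry(path: str) -> str:
    """Build a wheel RECORD-style line: path,sha256=<urlsafe b64 digest>,<byte size>."""
    data = Path(path).read_bytes()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=").decode()
    return f"{path},sha256={digest},{len(data)}"


# e.g. record_entry("liger_kernel/ops/layer_norm.py"), run against the unpacked
# new wheel, should reproduce the sha256/size pair shown above.
```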