liger-kernel-nightly 0.4.2.dev20241122175637__py3-none-any.whl → 0.4.2.dev20241123040418__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
liger_kernel/__init__.py ADDED (empty file — no content to show)
@@ -180,8 +180,13 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
180
180
  dY = dY.view(-1, dim)
181
181
  n_rows, n_cols = dY.shape
182
182
 
183
+ sm_count = 1
184
+ if X.device.type == "cuda":
185
+ sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
186
+ elif X.device.type == "xpu":
187
+ sm_count = torch.xpu.get_device_properties(X.device).gpu_subslice_count
188
+
183
189
  DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
184
- sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
185
190
  _DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
186
191
  _DB = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
187
192
 
@@ -264,6 +264,7 @@ def rms_norm_backward(
264
264
  dY = dY.view(-1, dim)
265
265
  n_rows, n_cols = dY.shape
266
266
 
267
+ sm_count = 1
267
268
  if X.device.type == "cuda":
268
269
  sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
269
270
  elif X.device.type == "xpu":
liger_kernel/ops/utils.py CHANGED
@@ -20,6 +20,8 @@ import triton
20
20
  import triton.language as tl
21
21
  from packaging.version import Version
22
22
 
23
+ from liger_kernel.utils import infer_device
24
+
23
25
 
24
26
  def is_hip() -> bool:
25
27
  return torch.version.hip is not None
@@ -69,10 +71,11 @@ def compare_version(package: str, operator: Callable, target: str):
69
71
 
70
72
 
71
73
  def get_amp_custom_fwd_bwd() -> Callable:
74
+ device = infer_device()
72
75
  if compare_version("torch", operator.ge, "2.4.0"):
73
76
  return (
74
- functools.partial(torch.amp.custom_fwd, device_type="cuda"),
75
- functools.partial(torch.amp.custom_bwd, device_type="cuda"),
77
+ functools.partial(torch.amp.custom_fwd, device_type=device),
78
+ functools.partial(torch.amp.custom_bwd, device_type=device),
76
79
  )
77
80
  return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd
78
81
 
liger_kernel/utils.py ADDED
@@ -0,0 +1,13 @@
1
+ import torch
2
+
3
+
4
+ def infer_device():
5
+ """
6
+ Get current device name based on available devices
7
+ """
8
+ if torch.cuda.is_available():
9
+ return "cuda"
10
+ elif torch.xpu.is_available():
11
+ return "xpu"
12
+ else:
13
+ return "cpu"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.4.2.dev20241122175637
3
+ Version: 0.4.2.dev20241123040418
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -1,4 +1,6 @@
1
+ liger_kernel/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
1
2
  liger_kernel/env_report.py,sha256=jye8RvUkmhqaIshdeIpoUABoAu7FPKJUib4FnAfvkpw,1132
3
+ liger_kernel/utils.py,sha256=HJa-xVKOohDn6pLVIx-Fv0V9h0QAL3qZGQNRICI-OpI,249
2
4
  liger_kernel/chunked_loss/__init__.py,sha256=R2wCcz4Y0kTAve926DH3k182XKezpXeACMHj05g9Mm8,346
3
5
  liger_kernel/chunked_loss/cpo_loss.py,sha256=H2L6mNtU8RMJ17u4aMZ9FHEfBvg1Z_hliY5-jZxiDBM,3079
4
6
  liger_kernel/chunked_loss/dpo_loss.py,sha256=XcCGLVmTVdEX30q41XRXXK_c-MSumVJ-l4tQwobUv2w,4228
@@ -14,12 +16,12 @@ liger_kernel/ops/geglu.py,sha256=MQL4zyzneZqZYUGPvb1QjI_EYT9_pKfSDgR25WD9jrI,412
14
16
  liger_kernel/ops/group_norm.py,sha256=VaRErVJGR4JqgXXvuIjNGTn3E2egjLtU1y3ymwIf4d8,10961
15
17
  liger_kernel/ops/jsd.py,sha256=Ap2b0_geCl6fqBXLI1IS6Yn6GlO-8LgPmnOW3y47dus,6151
16
18
  liger_kernel/ops/kl_div.py,sha256=03FNXfvCb6M-56hhFepAFV9p6brArPR6KOKkdGD34mw,8374
17
- liger_kernel/ops/layer_norm.py,sha256=unGMYMOPqtkM9aTrokhcqgPmsV2AUN7Yzv86isVB9OI,7422
19
+ liger_kernel/ops/layer_norm.py,sha256=_CZggw3GNEIUx5weDzadFit5I-Lzosoo8prgeJzcViY,7589
18
20
  liger_kernel/ops/qwen2vl_mrope.py,sha256=xZvQnhkSTjU-k6KiiRn9e0SYO1ESs1jmuZFMICduLpc,8552
19
- liger_kernel/ops/rms_norm.py,sha256=GKs49wXmUngY7MJ5QDQxTp4P2HDVqzZBaCr0pGtyZyM,11733
21
+ liger_kernel/ops/rms_norm.py,sha256=g7OXwuYI8-LXudDwvXuiupVjjOsbu8c4wwv83VaHa54,11750
20
22
  liger_kernel/ops/rope.py,sha256=jrzaA9-6Orn44y_IIam9_YNPQxOFK2FrIRNfFea4EtU,8513
21
23
  liger_kernel/ops/swiglu.py,sha256=Fwxtd76rhHKT9ShQAGca9RsnASplAVxtYKHmiT73_yA,2994
22
- liger_kernel/ops/utils.py,sha256=3JSF--O7KT5Wa5BuO70M4h0XetxoZ_e9IoW9GRlxlBg,3777
24
+ liger_kernel/ops/utils.py,sha256=_VQvd1PX5JXm5xaiBrk2gANp3qr4kM7qYG3ypkBwkMs,3850
23
25
  liger_kernel/ops/experimental/embedding.py,sha256=LYR66dB-jhvhtUjeV4PnNro-n77J1mdlmpSLSxB3Y6U,4186
24
26
  liger_kernel/ops/experimental/mm_int8int2.py,sha256=JpGVZCgRC6T8XMUJ_QbZRS2XU1bh0urIZphs5DTc1mY,13358
25
27
  liger_kernel/transformers/__init__.py,sha256=gia-eBxr7TLxU0GdDf8AfCY4WgDlFLqIGSt7EoQGsBA,1336
@@ -52,9 +54,9 @@ liger_kernel/transformers/model/qwen2.py,sha256=EyhSSzQOskGjSnCsKMZpd1s5IAIlHd5P
52
54
  liger_kernel/transformers/model/qwen2_vl.py,sha256=bIQe2bWiY--G84FhCD29Gdi64_qHP6vbcGsK6vKysQE,8547
53
55
  liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
54
56
  liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
55
- liger_kernel_nightly-0.4.2.dev20241122175637.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
56
- liger_kernel_nightly-0.4.2.dev20241122175637.dist-info/METADATA,sha256=jxja3ZGNuVc_U6JWdjw23O6IdeKK5h45_X16xh-e6xc,21891
57
- liger_kernel_nightly-0.4.2.dev20241122175637.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
58
- liger_kernel_nightly-0.4.2.dev20241122175637.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
59
- liger_kernel_nightly-0.4.2.dev20241122175637.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
60
- liger_kernel_nightly-0.4.2.dev20241122175637.dist-info/RECORD,,
57
+ liger_kernel_nightly-0.4.2.dev20241123040418.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
58
+ liger_kernel_nightly-0.4.2.dev20241123040418.dist-info/METADATA,sha256=lXW5-kGkMAutfiUrZflzYeW1bZo8efp65MDohQ3G1T0,21891
59
+ liger_kernel_nightly-0.4.2.dev20241123040418.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
60
+ liger_kernel_nightly-0.4.2.dev20241123040418.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
61
+ liger_kernel_nightly-0.4.2.dev20241123040418.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
62
+ liger_kernel_nightly-0.4.2.dev20241123040418.dist-info/RECORD,,