kernels 0.12.3__tar.gz → 0.14.0.dev0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {kernels-0.12.3 → kernels-0.14.0.dev0}/PKG-INFO +7 -5
- {kernels-0.12.3 → kernels-0.14.0.dev0}/pyproject.toml +15 -11
- {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels/__init__.py +4 -4
- kernels-0.14.0.dev0/src/kernels/_versions.py +72 -0
- kernels-0.14.0.dev0/src/kernels/backends.py +276 -0
- kernels-0.14.0.dev0/src/kernels/benchmark.py +38 -0
- kernels-0.14.0.dev0/src/kernels/benchmarks/activation.py +85 -0
- {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels/benchmarks/attention.py +11 -33
- {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels/benchmarks/layer_norm.py +3 -9
- kernels-0.12.3/src/kernels/cli.py → kernels-0.14.0.dev0/src/kernels/cli/__init__.py +32 -85
- {kernels-0.12.3/src/kernels → kernels-0.14.0.dev0/src/kernels/cli}/benchmark.py +217 -156
- kernels-0.14.0.dev0/src/kernels/cli/benchmark_graphics.py +740 -0
- {kernels-0.12.3/src/kernels → kernels-0.14.0.dev0/src/kernels/cli}/check.py +15 -8
- kernels-0.14.0.dev0/src/kernels/cli/versions.py +30 -0
- kernels-0.14.0.dev0/src/kernels/compat.py +14 -0
- kernels-0.14.0.dev0/src/kernels/deps.py +100 -0
- {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels/layer/device.py +19 -8
- {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels/layer/func.py +5 -10
- {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels/layer/globals.py +1 -3
- {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels/layer/kernelize.py +6 -15
- {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels/layer/layer.py +16 -43
- {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels/layer/repos.py +14 -33
- {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels/lockfile.py +35 -19
- kernels-0.14.0.dev0/src/kernels/metadata.py +44 -0
- kernels-0.14.0.dev0/src/kernels/status.py +81 -0
- {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels/utils.py +258 -209
- kernels-0.14.0.dev0/src/kernels/variants.py +452 -0
- {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels.egg-info/PKG-INFO +7 -5
- {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels.egg-info/SOURCES.txt +12 -8
- {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels.egg-info/requires.txt +5 -3
- {kernels-0.12.3 → kernels-0.14.0.dev0}/tests/test_basic.py +97 -9
- {kernels-0.12.3 → kernels-0.14.0.dev0}/tests/test_deps.py +1 -3
- {kernels-0.12.3 → kernels-0.14.0.dev0}/tests/test_doctest.py +1 -3
- {kernels-0.12.3 → kernels-0.14.0.dev0}/tests/test_func.py +1 -1
- {kernels-0.12.3 → kernels-0.14.0.dev0}/tests/test_interval_tree.py +3 -9
- kernels-0.14.0.dev0/tests/test_kernel_locking.py +100 -0
- {kernels-0.12.3 → kernels-0.14.0.dev0}/tests/test_layer.py +63 -157
- kernels-0.14.0.dev0/tests/test_status.py +94 -0
- kernels-0.14.0.dev0/tests/test_tvm_ffi.py +47 -0
- kernels-0.14.0.dev0/tests/test_user_agent.py +71 -0
- kernels-0.14.0.dev0/tests/test_variants.py +383 -0
- kernels-0.12.3/src/kernels/_vendored/convert_rst_to_mdx.py +0 -751
- kernels-0.12.3/src/kernels/_versions.py +0 -95
- kernels-0.12.3/src/kernels/benchmarks/activation.py +0 -44
- kernels-0.12.3/src/kernels/compat.py +0 -8
- kernels-0.12.3/src/kernels/deps.py +0 -59
- kernels-0.12.3/src/kernels/doc.py +0 -242
- kernels-0.12.3/src/kernels/metadata.py +0 -37
- kernels-0.12.3/src/kernels/upload.py +0 -82
- kernels-0.12.3/src/kernels/variants.py +0 -3
- kernels-0.12.3/src/kernels/versions_cli.py +0 -38
- kernels-0.12.3/tests/test_kernel_locking.py +0 -208
- kernels-0.12.3/tests/test_kernel_upload.py +0 -121
- {kernels-0.12.3 → kernels-0.14.0.dev0}/README.md +0 -0
- {kernels-0.12.3 → kernels-0.14.0.dev0}/setup.cfg +0 -0
- {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels/_system.py +0 -0
- {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels/_windows.py +0 -0
- {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels/benchmarks/__init__.py +0 -0
- {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels/layer/__init__.py +0 -0
- {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels/layer/_interval_tree.py +0 -0
- {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels/layer/mode.py +0 -0
- {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels/python_depends.json +0 -0
- {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels.egg-info/dependency_links.txt +0 -0
- {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels.egg-info/entry_points.txt +0 -0
- {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels.egg-info/top_level.txt +0 -0
- {kernels-0.12.3 → kernels-0.14.0.dev0}/tests/test_benchmarks.py +0 -0
|
@@ -1,20 +1,22 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: kernels
|
|
3
|
-
Version: 0.12.3
|
|
3
|
+
Version: 0.14.0.dev0
|
|
4
4
|
Summary: Download compute kernels
|
|
5
5
|
Author-email: Daniel de Kok <daniel@huggingface.co>, David Holtz <david@huggingface.co>
|
|
6
6
|
License: Apache-2.0
|
|
7
|
-
Requires-Python: >=3.
|
|
7
|
+
Requires-Python: >=3.10
|
|
8
8
|
Description-Content-Type: text/markdown
|
|
9
|
-
Requires-Dist:
|
|
9
|
+
Requires-Dist: huggingface-hub>=1.10.0
|
|
10
10
|
Requires-Dist: packaging>=20.0
|
|
11
11
|
Requires-Dist: pyyaml>=6
|
|
12
12
|
Requires-Dist: tomli>=2.0; python_version < "3.11"
|
|
13
|
+
Requires-Dist: tomlkit>=0.13.3
|
|
13
14
|
Provides-Extra: abi-check
|
|
14
|
-
Requires-Dist: kernel-abi-check<0.
|
|
15
|
+
Requires-Dist: kernel-abi-check<0.7.0,>=0.6.2; extra == "abi-check"
|
|
15
16
|
Provides-Extra: benchmark
|
|
17
|
+
Requires-Dist: matplotlib>=3.7.0; extra == "benchmark"
|
|
16
18
|
Requires-Dist: numpy>=2.0.2; extra == "benchmark"
|
|
17
|
-
Requires-Dist:
|
|
19
|
+
Requires-Dist: tabulate>=0.9.0; extra == "benchmark"
|
|
18
20
|
Requires-Dist: torch; extra == "benchmark"
|
|
19
21
|
Provides-Extra: torch
|
|
20
22
|
Requires-Dist: torch; extra == "torch"
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "kernels"
|
|
3
|
-
version = "0.12.3"
|
|
3
|
+
version = "0.14.0.dev0"
|
|
4
4
|
description = "Download compute kernels"
|
|
5
5
|
authors = [
|
|
6
6
|
{ name = "Daniel de Kok", email = "daniel@huggingface.co" },
|
|
@@ -8,12 +8,13 @@ authors = [
|
|
|
8
8
|
]
|
|
9
9
|
license = { text = "Apache-2.0" }
|
|
10
10
|
readme = "README.md"
|
|
11
|
-
requires-python = ">= 3.
|
|
11
|
+
requires-python = ">= 3.10"
|
|
12
12
|
dependencies = [
|
|
13
|
-
"
|
|
13
|
+
"huggingface-hub>=1.10.0",
|
|
14
14
|
"packaging>=20.0",
|
|
15
15
|
"pyyaml>=6",
|
|
16
16
|
"tomli>=2.0; python_version<'3.11'",
|
|
17
|
+
"tomlkit>=0.13.3",
|
|
17
18
|
]
|
|
18
19
|
|
|
19
20
|
[build-system]
|
|
@@ -28,15 +29,17 @@ dev = [
|
|
|
28
29
|
# Whatever version is compatible with pytest.
|
|
29
30
|
"pytest-benchmark",
|
|
30
31
|
"torch>=2.5",
|
|
32
|
+
"apache-tvm-ffi>=0.1.9,<0.2.0",
|
|
31
33
|
"types-pyyaml",
|
|
32
|
-
"types-
|
|
34
|
+
"types-tabulate",
|
|
33
35
|
]
|
|
34
36
|
|
|
35
37
|
[project.optional-dependencies]
|
|
36
|
-
abi-check = ["kernel-abi-check>=0.
|
|
38
|
+
abi-check = ["kernel-abi-check>=0.6.2,<0.7.0"]
|
|
37
39
|
benchmark = [
|
|
40
|
+
"matplotlib>=3.7.0",
|
|
38
41
|
"numpy>=2.0.2",
|
|
39
|
-
"
|
|
42
|
+
"tabulate>=0.9.0",
|
|
40
43
|
"torch",
|
|
41
44
|
]
|
|
42
45
|
torch = ["torch"]
|
|
@@ -53,11 +56,10 @@ kernels = "kernels.cli:main"
|
|
|
53
56
|
[tool.setuptools.package-data]
|
|
54
57
|
kernels = ["python_depends.json"]
|
|
55
58
|
|
|
56
|
-
[tool.isort]
|
|
57
|
-
profile = "black"
|
|
58
|
-
line_length = 119
|
|
59
|
-
|
|
60
59
|
[tool.ruff]
|
|
60
|
+
# If the version is changed, apply the change in the Nix overlay
|
|
61
|
+
# as well.
|
|
62
|
+
required-version = "==0.15.10"
|
|
61
63
|
exclude = [
|
|
62
64
|
".eggs",
|
|
63
65
|
".git",
|
|
@@ -82,4 +84,6 @@ line-length = 119
|
|
|
82
84
|
# Ignored rules:
|
|
83
85
|
# "E501" -> line length violation
|
|
84
86
|
lint.ignore = ["E501"]
|
|
85
|
-
lint.select = ["E", "F", "W"]
|
|
87
|
+
lint.select = ["E", "F", "I", "W"]
|
|
88
|
+
|
|
89
|
+
[tool.ruff.format]
|
|
@@ -2,6 +2,8 @@ import importlib.metadata
|
|
|
2
2
|
|
|
3
3
|
__version__ = importlib.metadata.version("kernels")
|
|
4
4
|
|
|
5
|
+
from kernels._windows import _add_additional_dll_paths
|
|
6
|
+
from kernels.benchmark import Benchmark
|
|
5
7
|
from kernels.layer import (
|
|
6
8
|
CUDAProperties,
|
|
7
9
|
Device,
|
|
@@ -21,16 +23,13 @@ from kernels.layer import (
|
|
|
21
23
|
)
|
|
22
24
|
from kernels.utils import (
|
|
23
25
|
get_kernel,
|
|
26
|
+
get_loaded_kernels,
|
|
24
27
|
get_local_kernel,
|
|
25
28
|
get_locked_kernel,
|
|
26
29
|
has_kernel,
|
|
27
30
|
install_kernel,
|
|
28
31
|
load_kernel,
|
|
29
32
|
)
|
|
30
|
-
from kernels.benchmark import Benchmark
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
from kernels._windows import _add_additional_dll_paths
|
|
34
33
|
|
|
35
34
|
_add_additional_dll_paths()
|
|
36
35
|
|
|
@@ -47,6 +46,7 @@ __all__ = [
|
|
|
47
46
|
"LockedLayerRepository",
|
|
48
47
|
"Mode",
|
|
49
48
|
"get_kernel",
|
|
49
|
+
"get_loaded_kernels",
|
|
50
50
|
"get_local_kernel",
|
|
51
51
|
"get_locked_kernel",
|
|
52
52
|
"has_kernel",
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import warnings
|
|
3
|
+
|
|
4
|
+
from huggingface_hub.hf_api import GitRefInfo
|
|
5
|
+
|
|
6
|
+
# Module logger; resolve_version_spec_as_ref warns through it when a newer kernel version exists.
logger = logging.getLogger(__name__)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def _get_available_versions(repo_id: str) -> dict[int, GitRefInfo]:
    """Get kernel versions that are available in the repository.

    A version is a branch named ``v<N>`` where ``<N>`` consists solely of
    ASCII digits. All other branches are ignored.

    Args:
        repo_id: Hub repository of the kernel.

    Returns:
        Mapping from integer version to the branch's ``GitRefInfo``.
    """
    # Local import; presumably avoids an import cycle with kernels.utils (confirm).
    from kernels.utils import _get_hf_api

    refs = _get_hf_api().list_repo_refs(repo_id=repo_id, repo_type="kernel")

    versions: dict[int, GitRefInfo] = {}
    for branch in refs.branches:
        if not branch.name.startswith("v"):
            continue
        suffix = branch.name[1:]
        # int() alone would also accept signs, surrounding whitespace and
        # digit-group underscores (e.g. "v1_0" -> 10, "v-1" -> -1). Only plain
        # digit runs are treated as versions.
        if not (suffix.isascii() and suffix.isdigit()):
            continue
        versions[int(suffix)] = branch

    return versions
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def resolve_version_spec_as_ref(repo_id: str, version_spec: int) -> GitRefInfo:
    """
    Get the ref for a kernel with the given version.

    Args:
        repo_id: Hub repository of the kernel.
        version_spec: Integer kernel version, matched against ``v<N>`` branches.

    Returns:
        The ``GitRefInfo`` of the branch for ``version_spec``. A warning is
        logged when a newer version than ``version_spec`` is available.

    Raises:
        ValueError: If no branch exists for ``version_spec``.
    """
    versions = _get_available_versions(repo_id)

    ref = versions.get(version_spec)
    if ref is None:
        # Sort numerically before stringifying, so that e.g. 10 is listed after 2.
        available = ", ".join(str(v) for v in sorted(versions))
        raise ValueError(f"Version {version_spec} not found, available versions: {available}")

    latest_version = max(versions)
    if version_spec < latest_version:
        logger.warning(
            "You are using version %d of '%s', but version %d is available.",
            version_spec,
            repo_id,
            latest_version,
        )

    return ref
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def select_revision_or_version(
|
|
52
|
+
repo_id: str,
|
|
53
|
+
*,
|
|
54
|
+
revision: str | None,
|
|
55
|
+
version: int | None,
|
|
56
|
+
) -> str:
|
|
57
|
+
if revision is not None and version is not None:
|
|
58
|
+
raise ValueError("Only one of `revision` or `version` must be specified.")
|
|
59
|
+
|
|
60
|
+
if revision is not None:
|
|
61
|
+
return revision
|
|
62
|
+
elif version is not None:
|
|
63
|
+
return resolve_version_spec_as_ref(repo_id, version).target_commit
|
|
64
|
+
|
|
65
|
+
warnings.warn(
|
|
66
|
+
"Future versions of `kernels` (>=0.15) will require specifying a kernel version or revision. "
|
|
67
|
+
"See: https://huggingface.co/docs/kernels/migration",
|
|
68
|
+
FutureWarning,
|
|
69
|
+
stacklevel=2,
|
|
70
|
+
)
|
|
71
|
+
|
|
72
|
+
return "main"
|
|
@@ -0,0 +1,276 @@
|
|
|
1
|
+
import ctypes
|
|
2
|
+
import ctypes.util
|
|
3
|
+
import re
|
|
4
|
+
import warnings
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import ClassVar, Optional, Protocol, runtime_checkable
|
|
7
|
+
|
|
8
|
+
from huggingface_hub.dataclasses import strict
|
|
9
|
+
from packaging.version import Version
|
|
10
|
+
|
|
11
|
+
from kernels.compat import has_torch
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@runtime_checkable
class Backend(Protocol):
    """Structural interface shared by all compute backends.

    Any object exposing ``name`` and ``variant_str`` properties satisfies this
    protocol. Because it is runtime-checkable, ``isinstance`` can be used.
    """

    @property
    def name(self) -> str:
        """Short backend identifier such as "cuda", "rocm" or "cpu"."""
        ...

    @property
    def variant_str(self) -> str:
        """Backend label as it appears in a build variant, e.g. `cu128` for CUDA 12.8."""
        ...
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass(unsafe_hash=True)
class CANN:
    """CANN backend (detected via the ``npu`` Torch privateuse1 backend).

    Variant strings have the form ``cann<major><minor>``. When parsing, the
    final digit is taken as the minor version and the rest as the major.
    """

    _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"cann(\d+)(\d+)")

    version: Version

    @property
    def name(self) -> str:
        return "cann"

    @property
    def variant_str(self) -> str:
        ver = self.version
        return "cann" + f"{ver.major}{ver.minor}"

    @staticmethod
    def parse(s: str) -> "CANN":
        """Parse a ``cann<major><minor>`` variant string.

        Raises:
            ValueError: If ``s`` does not match the variant pattern.
        """
        found = CANN._VARIANT_REGEX.fullmatch(s)
        if found is None:
            raise ValueError(f"Invalid CANN variant string: {s!r}")
        major, minor = found.groups()
        return CANN(version=Version(f"{major}.{minor}"))
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@strict
@dataclass(unsafe_hash=True)
class CPU:
    """Plain CPU backend; its name and variant string are both ``cpu``."""

    @property
    def name(self) -> str:
        return "cpu"

    @property
    def variant_str(self) -> str:
        return self.name

    @staticmethod
    def parse(s: str) -> "CPU":
        """Parse the ``cpu`` variant string; any other value raises ``ValueError``."""
        if s == "cpu":
            return CPU()
        raise ValueError(f"Invalid CPU variant string: {s!r}")
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@dataclass(unsafe_hash=True)
class CUDA:
    """CUDA backend identified by its CUDA version.

    Variant strings have the form ``cu<major><minor>`` (e.g. ``cu128`` for
    12.8). When parsing, the final digit is the minor version.
    """

    _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"cu(\d+)(\d+)")

    version: Version

    @property
    def name(self) -> str:
        return "cuda"

    @property
    def variant_str(self) -> str:
        ver = self.version
        return "cu" + f"{ver.major}{ver.minor}"

    @staticmethod
    def parse(s: str) -> "CUDA":
        """Parse a ``cu<major><minor>`` variant string.

        Raises:
            ValueError: If ``s`` does not match the variant pattern.
        """
        found = CUDA._VARIANT_REGEX.fullmatch(s)
        if found is None:
            raise ValueError(f"Invalid CUDA variant string: {s!r}")
        major, minor = found.groups()
        return CUDA(version=Version(f"{major}.{minor}"))
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
@strict
@dataclass(unsafe_hash=True)
class Metal:
    """Metal (MPS) backend; its name and variant string are both ``metal``."""

    @property
    def name(self) -> str:
        return "metal"

    @property
    def variant_str(self) -> str:
        return self.name

    @staticmethod
    def parse(s: str) -> "Metal":
        """Parse the ``metal`` variant string; any other value raises ``ValueError``."""
        if s == "metal":
            return Metal()
        raise ValueError(f"Invalid Metal variant string: {s!r}")
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
@strict
@dataclass(unsafe_hash=True)
class Neuron:
    """Neuron backend; its name and variant string are both ``neuron``."""

    @property
    def name(self) -> str:
        return "neuron"

    @property
    def variant_str(self) -> str:
        return self.name

    @staticmethod
    def parse(s: str) -> "Neuron":
        """Parse the ``neuron`` variant string; any other value raises ``ValueError``."""
        if s == "neuron":
            return Neuron()
        raise ValueError(f"Invalid Neuron variant string: {s!r}")
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
@dataclass(unsafe_hash=True)
class ROCm:
    """ROCm backend identified by its ROCm (HIP) version.

    Variant strings have the form ``rocm<major><minor>`` (e.g. ``rocm61``).
    When parsing, the final digit is the minor version.
    """

    _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"rocm(\d+)(\d+)")

    version: Version

    @property
    def name(self) -> str:
        return "rocm"

    @property
    def variant_str(self) -> str:
        ver = self.version
        return "rocm" + f"{ver.major}{ver.minor}"

    @staticmethod
    def parse(s: str) -> "ROCm":
        """Parse a ``rocm<major><minor>`` variant string.

        Raises:
            ValueError: If ``s`` does not match the variant pattern.
        """
        found = ROCm._VARIANT_REGEX.fullmatch(s)
        if found is None:
            raise ValueError(f"Invalid ROCm variant string: {s!r}")
        major, minor = found.groups()
        return ROCm(version=Version(f"{major}.{minor}"))
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
@dataclass(unsafe_hash=True)
class XPU:
    """XPU backend identified by its XPU version.

    Variant strings have the form ``xpu<major><minor>``. When parsing, the
    final digit is the minor version and the rest is the major.
    """

    _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"xpu(\d+)(\d+)")

    version: Version

    @property
    def name(self) -> str:
        return "xpu"

    @property
    def variant_str(self) -> str:
        ver = self.version
        return "xpu" + f"{ver.major}{ver.minor}"

    @staticmethod
    def parse(s: str) -> "XPU":
        """Parse an ``xpu<major><minor>`` variant string.

        Raises:
            ValueError: If ``s`` does not match the variant pattern.
        """
        found = XPU._VARIANT_REGEX.fullmatch(s)
        if found is None:
            raise ValueError(f"Invalid XPU variant string: {s!r}")
        major, minor = found.groups()
        return XPU(version=Version(f"{major}.{minor}"))
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def parse_backend(s: str) -> Backend:
    """Parse a backend variant string (e.g. 'cu128', 'rocm61', 'cpu') into a Backend.

    Exact names are checked first, then the versioned prefixes in a fixed
    order. The matching class's ``parse`` performs the detailed validation.

    Raises:
        ValueError: If ``s`` matches no known backend, or fails that backend's parse.
    """
    exact_names = {"cpu": CPU, "metal": Metal, "neuron": Neuron}
    if s in exact_names:
        return exact_names[s].parse(s)

    prefixed = (("cu", CUDA), ("rocm", ROCm), ("xpu", XPU), ("cann", CANN))
    for prefix, backend_cls in prefixed:
        if s.startswith(prefix):
            return backend_cls.parse(s)

    raise ValueError(f"Unknown backend variant string: {s!r}")
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def _backend() -> Backend:
    """Detect the backend of the current environment.

    With Torch installed, the backend is derived from the Torch build. The
    checks run in this order: Neuron, CUDA, ROCm (HIP), Metal (MPS), XPU,
    then CANN (``npu`` privateuse1 backend), with CPU as the fallback.
    Without Torch, the system CUDA runtime is probed, falling back to CPU.
    """
    if has_torch:
        import torch

        if hasattr(torch, "neuron"):
            # Needs to be sorted before specific Torch builds, since Neuron
            # extension can be loaded into e.g. CUDA Torch builds.
            return Neuron()
        elif torch.version.cuda is not None:
            cuda_version = Version(torch.version.cuda)
            return CUDA(version=cuda_version)
        elif torch.version.hip is not None:
            # Drop any "-<suffix>" after the HIP version number before parsing.
            rocm_version = Version(torch.version.hip.split("-")[0])
            return ROCm(version=rocm_version)
        elif torch.backends.mps.is_available():
            return Metal()
        elif hasattr(torch.version, "xpu") and torch.version.xpu is not None:
            # Builds "<chars 0-3>.<char 5>" from torch.version.xpu.
            # NOTE(review): assumes a fixed-width date-like string — confirm.
            version = f"{torch.version.xpu[0:4]}.{torch.version.xpu[5:6]}"
            return XPU(version=Version(version))
        elif _get_torch_privateuse_backend_name() == "npu":
            from torch_npu.utils.collect_env import get_cann_version  # type: ignore[import-not-found]

            # Query the external version lookup only once.
            # NOTE(review): assumes a dotted string such as "8.0.RC1". Splitting
            # on "." keeps multi-digit components intact, unlike picking single
            # characters at fixed positions.
            cann_major, cann_minor = get_cann_version().split(".")[:2]
            return CANN(version=Version(f"{cann_major}.{cann_minor}"))
        else:
            return CPU()
    else:
        cuda = _get_cuda()
        if cuda is not None:
            return cuda

        return CPU()
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def _get_torch_privateuse_backend_name() -> str | None:
|
|
229
|
+
import torch
|
|
230
|
+
|
|
231
|
+
if hasattr(torch._C, "_get_privateuse1_backend_name"):
|
|
232
|
+
return torch._C._get_privateuse1_backend_name()
|
|
233
|
+
return None
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def _select_backend(backend: str | None) -> Backend:
    """Resolve a requested backend name, auto-detecting when ``backend`` is None.

    Raises:
        ValueError: If ``backend`` is not among the backends supported on this system.
    """
    if backend is None:
        return _backend()

    supported = _supported_backends()
    chosen = supported.get(backend)
    if chosen is not None:
        return chosen

    names = ", ".join(sorted(supported.keys()))
    raise ValueError(f"Invalid backend '{backend}', system supported backends: {names}")
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def _supported_backends() -> dict[str, Backend]:
    """Map backend name to backend: always CPU, plus whatever backend is detected."""
    detected = _backend()
    supported: dict[str, Backend] = {"cpu": CPU()}
    # When detection itself yields CPU, this replaces the entry above.
    supported[detected.name] = detected
    return supported
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
def _get_cuda() -> Optional[CUDA]:
    """Probe the system CUDA runtime library (``cudart``) via ctypes.

    Returns:
        A ``CUDA`` backend carrying the runtime version, or ``None`` when the
        library cannot be found or loaded, or when the version query fails.
        A failed query also emits a warning.
    """
    found = ctypes.util.find_library("cudart")
    if found is None:
        return None

    try:
        cudart = ctypes.CDLL(found)
    except OSError:
        return None

    encoded = ctypes.c_int(0)
    status = cudart.cudaRuntimeGetVersion(ctypes.byref(encoded))
    if status != 0:
        warnings.warn("System has CUDA runtime library, but cannot get runtime version.")
        return None

    # The runtime reports major * 1000 + minor * 10 (e.g. 12080 for 12.8).
    major, remainder = divmod(encoded.value, 1000)
    minor = remainder // 10
    return CUDA(version=Version(f"{major}.{minor}"))
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class Benchmark:
|
|
5
|
+
"""Base class for kernel benchmarks.
|
|
6
|
+
|
|
7
|
+
Subclass this to create a benchmark script with automatic timing,
|
|
8
|
+
verification, and reproducibility support. The kernel is loaded
|
|
9
|
+
automatically from the repo_id specified in the CLI command.
|
|
10
|
+
|
|
11
|
+
Example:
|
|
12
|
+
class MyBenchmark(Benchmark):
|
|
13
|
+
seed = 42
|
|
14
|
+
|
|
15
|
+
def setup(self):
|
|
16
|
+
self.x = torch.randn(128, 1024, device=self.device, dtype=torch.float16)
|
|
17
|
+
self.out = torch.empty(128, 512, device=self.device, dtype=torch.float16)
|
|
18
|
+
|
|
19
|
+
def benchmark_silu(self):
|
|
20
|
+
self.kernel.silu_and_mul(self.out, self.x)
|
|
21
|
+
|
|
22
|
+
def verify_silu(self) -> torch.Tensor:
|
|
23
|
+
# Return reference tensor; runner compares with self.out
|
|
24
|
+
return torch.nn.functional.silu(self.x[..., :512]) * self.x[..., 512:]
|
|
25
|
+
|
|
26
|
+
Run with: kernels benchmark <repo_id>
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
seed: int | None = None # Optional: seed for reproducibility
|
|
30
|
+
device: str = "cpu" # Set automatically by runner
|
|
31
|
+
|
|
32
|
+
def __init__(self) -> None:
|
|
33
|
+
self.kernel: Any = None
|
|
34
|
+
self.out: Any = None # Output tensor, set by setup methods
|
|
35
|
+
|
|
36
|
+
def setup(self) -> None:
|
|
37
|
+
"""Override to set up tensors as instance attributes."""
|
|
38
|
+
pass
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn.functional as F
|
|
3
|
+
|
|
4
|
+
from kernels.benchmark import Benchmark
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class SiluAndMulBenchmark(Benchmark):
    """Benchmark ``kernel.silu_and_mul`` on small, medium and large fp16 inputs.

    Each workload creates ``self.x`` with shape ``(8, T, 2 * D)`` and
    ``self.out`` with shape ``(8, T, D)``. The reference result is
    ``silu(x[..., :D]) * x[..., D:]``.
    """

    seed: int = 42

    # Workload: small
    def setup_small(self):
        self.x = torch.randn(8, 1024, 2048, device=self.device, dtype=torch.float16)
        self.out = torch.empty(8, 1024, 1024, device=self.device, dtype=torch.float16)

    def benchmark_small(self):
        self.kernel.silu_and_mul(self.out, self.x)

    def verify_small(self) -> torch.Tensor:
        d = self.x.shape[-1] // 2
        return F.silu(self.x[..., :d]) * self.x[..., d:]

    # Workload: medium
    def setup_medium(self):
        self.x = torch.randn(8, 2048, 4096, device=self.device, dtype=torch.float16)
        self.out = torch.empty(8, 2048, 2048, device=self.device, dtype=torch.float16)

    def benchmark_medium(self):
        self.kernel.silu_and_mul(self.out, self.x)

    def verify_medium(self) -> torch.Tensor:
        d = self.x.shape[-1] // 2
        return F.silu(self.x[..., :d]) * self.x[..., d:]

    # Workload: large
    def setup_large(self):
        self.x = torch.randn(8, 4096, 8192, device=self.device, dtype=torch.float16)
        self.out = torch.empty(8, 4096, 4096, device=self.device, dtype=torch.float16)

    def benchmark_large(self):
        # Exactly one kernel call per benchmark invocation, as in the other
        # workloads; a second call would double the measured time.
        self.kernel.silu_and_mul(self.out, self.x)

    def verify_large(self) -> torch.Tensor:
        d = self.x.shape[-1] // 2
        return F.silu(self.x[..., :d]) * self.x[..., d:]
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class GeluAndMulBenchmark(Benchmark):
    """Benchmark ``kernel.gelu_and_mul`` on three fp16 workload sizes.

    Inputs have shape ``(8, T, 2 * D)`` and outputs ``(8, T, D)``. The
    reference result is ``gelu(x[..., :D]) * x[..., D:]``.
    """

    seed: int = 42

    # Workload: small
    def setup_small(self):
        batch, tokens, width = 8, 1024, 2048
        self.x = torch.randn(batch, tokens, width, device=self.device, dtype=torch.float16)
        self.out = torch.empty(batch, tokens, width // 2, device=self.device, dtype=torch.float16)

    def benchmark_small(self):
        self.kernel.gelu_and_mul(self.out, self.x)

    def verify_small(self) -> torch.Tensor:
        half = self.x.shape[-1] // 2
        lhs, rhs = self.x[..., :half], self.x[..., half:]
        return F.gelu(lhs) * rhs

    # Workload: medium
    def setup_medium(self):
        batch, tokens, width = 8, 2048, 4096
        self.x = torch.randn(batch, tokens, width, device=self.device, dtype=torch.float16)
        self.out = torch.empty(batch, tokens, width // 2, device=self.device, dtype=torch.float16)

    def benchmark_medium(self):
        self.kernel.gelu_and_mul(self.out, self.x)

    def verify_medium(self) -> torch.Tensor:
        half = self.x.shape[-1] // 2
        lhs, rhs = self.x[..., :half], self.x[..., half:]
        return F.gelu(lhs) * rhs

    # Workload: large
    def setup_large(self):
        batch, tokens, width = 8, 4096, 8192
        self.x = torch.randn(batch, tokens, width, device=self.device, dtype=torch.float16)
        self.out = torch.empty(batch, tokens, width // 2, device=self.device, dtype=torch.float16)

    def benchmark_large(self):
        self.kernel.gelu_and_mul(self.out, self.x)

    def verify_large(self) -> torch.Tensor:
        half = self.x.shape[-1] // 2
        lhs, rhs = self.x[..., :half], self.x[..., half:]
        return F.gelu(lhs) * rhs
|