kernels 0.12.3__tar.gz → 0.14.0.dev0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66) hide show
  1. {kernels-0.12.3 → kernels-0.14.0.dev0}/PKG-INFO +7 -5
  2. {kernels-0.12.3 → kernels-0.14.0.dev0}/pyproject.toml +15 -11
  3. {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels/__init__.py +4 -4
  4. kernels-0.14.0.dev0/src/kernels/_versions.py +72 -0
  5. kernels-0.14.0.dev0/src/kernels/backends.py +276 -0
  6. kernels-0.14.0.dev0/src/kernels/benchmark.py +38 -0
  7. kernels-0.14.0.dev0/src/kernels/benchmarks/activation.py +85 -0
  8. {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels/benchmarks/attention.py +11 -33
  9. {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels/benchmarks/layer_norm.py +3 -9
  10. kernels-0.12.3/src/kernels/cli.py → kernels-0.14.0.dev0/src/kernels/cli/__init__.py +32 -85
  11. {kernels-0.12.3/src/kernels → kernels-0.14.0.dev0/src/kernels/cli}/benchmark.py +217 -156
  12. kernels-0.14.0.dev0/src/kernels/cli/benchmark_graphics.py +740 -0
  13. {kernels-0.12.3/src/kernels → kernels-0.14.0.dev0/src/kernels/cli}/check.py +15 -8
  14. kernels-0.14.0.dev0/src/kernels/cli/versions.py +30 -0
  15. kernels-0.14.0.dev0/src/kernels/compat.py +14 -0
  16. kernels-0.14.0.dev0/src/kernels/deps.py +100 -0
  17. {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels/layer/device.py +19 -8
  18. {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels/layer/func.py +5 -10
  19. {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels/layer/globals.py +1 -3
  20. {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels/layer/kernelize.py +6 -15
  21. {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels/layer/layer.py +16 -43
  22. {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels/layer/repos.py +14 -33
  23. {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels/lockfile.py +35 -19
  24. kernels-0.14.0.dev0/src/kernels/metadata.py +44 -0
  25. kernels-0.14.0.dev0/src/kernels/status.py +81 -0
  26. {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels/utils.py +258 -209
  27. kernels-0.14.0.dev0/src/kernels/variants.py +452 -0
  28. {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels.egg-info/PKG-INFO +7 -5
  29. {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels.egg-info/SOURCES.txt +12 -8
  30. {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels.egg-info/requires.txt +5 -3
  31. {kernels-0.12.3 → kernels-0.14.0.dev0}/tests/test_basic.py +97 -9
  32. {kernels-0.12.3 → kernels-0.14.0.dev0}/tests/test_deps.py +1 -3
  33. {kernels-0.12.3 → kernels-0.14.0.dev0}/tests/test_doctest.py +1 -3
  34. {kernels-0.12.3 → kernels-0.14.0.dev0}/tests/test_func.py +1 -1
  35. {kernels-0.12.3 → kernels-0.14.0.dev0}/tests/test_interval_tree.py +3 -9
  36. kernels-0.14.0.dev0/tests/test_kernel_locking.py +100 -0
  37. {kernels-0.12.3 → kernels-0.14.0.dev0}/tests/test_layer.py +63 -157
  38. kernels-0.14.0.dev0/tests/test_status.py +94 -0
  39. kernels-0.14.0.dev0/tests/test_tvm_ffi.py +47 -0
  40. kernels-0.14.0.dev0/tests/test_user_agent.py +71 -0
  41. kernels-0.14.0.dev0/tests/test_variants.py +383 -0
  42. kernels-0.12.3/src/kernels/_vendored/convert_rst_to_mdx.py +0 -751
  43. kernels-0.12.3/src/kernels/_versions.py +0 -95
  44. kernels-0.12.3/src/kernels/benchmarks/activation.py +0 -44
  45. kernels-0.12.3/src/kernels/compat.py +0 -8
  46. kernels-0.12.3/src/kernels/deps.py +0 -59
  47. kernels-0.12.3/src/kernels/doc.py +0 -242
  48. kernels-0.12.3/src/kernels/metadata.py +0 -37
  49. kernels-0.12.3/src/kernels/upload.py +0 -82
  50. kernels-0.12.3/src/kernels/variants.py +0 -3
  51. kernels-0.12.3/src/kernels/versions_cli.py +0 -38
  52. kernels-0.12.3/tests/test_kernel_locking.py +0 -208
  53. kernels-0.12.3/tests/test_kernel_upload.py +0 -121
  54. {kernels-0.12.3 → kernels-0.14.0.dev0}/README.md +0 -0
  55. {kernels-0.12.3 → kernels-0.14.0.dev0}/setup.cfg +0 -0
  56. {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels/_system.py +0 -0
  57. {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels/_windows.py +0 -0
  58. {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels/benchmarks/__init__.py +0 -0
  59. {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels/layer/__init__.py +0 -0
  60. {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels/layer/_interval_tree.py +0 -0
  61. {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels/layer/mode.py +0 -0
  62. {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels/python_depends.json +0 -0
  63. {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels.egg-info/dependency_links.txt +0 -0
  64. {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels.egg-info/entry_points.txt +0 -0
  65. {kernels-0.12.3 → kernels-0.14.0.dev0}/src/kernels.egg-info/top_level.txt +0 -0
  66. {kernels-0.12.3 → kernels-0.14.0.dev0}/tests/test_benchmarks.py +0 -0
@@ -1,20 +1,22 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: kernels
3
- Version: 0.12.3
3
+ Version: 0.14.0.dev0
4
4
  Summary: Download compute kernels
5
5
  Author-email: Daniel de Kok <daniel@huggingface.co>, David Holtz <david@huggingface.co>
6
6
  License: Apache-2.0
7
- Requires-Python: >=3.9
7
+ Requires-Python: >=3.10
8
8
  Description-Content-Type: text/markdown
9
- Requires-Dist: huggingface_hub<2.0,>=0.26.0
9
+ Requires-Dist: huggingface-hub>=1.10.0
10
10
  Requires-Dist: packaging>=20.0
11
11
  Requires-Dist: pyyaml>=6
12
12
  Requires-Dist: tomli>=2.0; python_version < "3.11"
13
+ Requires-Dist: tomlkit>=0.13.3
13
14
  Provides-Extra: abi-check
14
- Requires-Dist: kernel-abi-check<0.13.0,>=0.12.0; extra == "abi-check"
15
+ Requires-Dist: kernel-abi-check<0.7.0,>=0.6.2; extra == "abi-check"
15
16
  Provides-Extra: benchmark
17
+ Requires-Dist: matplotlib>=3.7.0; extra == "benchmark"
16
18
  Requires-Dist: numpy>=2.0.2; extra == "benchmark"
17
- Requires-Dist: requests>=2.32.5; extra == "benchmark"
19
+ Requires-Dist: tabulate>=0.9.0; extra == "benchmark"
18
20
  Requires-Dist: torch; extra == "benchmark"
19
21
  Provides-Extra: torch
20
22
  Requires-Dist: torch; extra == "torch"
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "kernels"
3
- version = "0.12.3"
3
+ version = "0.14.0.dev0"
4
4
  description = "Download compute kernels"
5
5
  authors = [
6
6
  { name = "Daniel de Kok", email = "daniel@huggingface.co" },
@@ -8,12 +8,13 @@ authors = [
8
8
  ]
9
9
  license = { text = "Apache-2.0" }
10
10
  readme = "README.md"
11
- requires-python = ">= 3.9"
11
+ requires-python = ">= 3.10"
12
12
  dependencies = [
13
- "huggingface_hub>=0.26.0,<2.0",
13
+ "huggingface-hub>=1.10.0",
14
14
  "packaging>=20.0",
15
15
  "pyyaml>=6",
16
16
  "tomli>=2.0; python_version<'3.11'",
17
+ "tomlkit>=0.13.3",
17
18
  ]
18
19
 
19
20
  [build-system]
@@ -28,15 +29,17 @@ dev = [
28
29
  # Whatever version is compatible with pytest.
29
30
  "pytest-benchmark",
30
31
  "torch>=2.5",
32
+ "apache-tvm-ffi>=0.1.9,<0.2.0",
31
33
  "types-pyyaml",
32
- "types-requests",
34
+ "types-tabulate",
33
35
  ]
34
36
 
35
37
  [project.optional-dependencies]
36
- abi-check = ["kernel-abi-check>=0.12.0,<0.13.0"]
38
+ abi-check = ["kernel-abi-check>=0.6.2,<0.7.0"]
37
39
  benchmark = [
40
+ "matplotlib>=3.7.0",
38
41
  "numpy>=2.0.2",
39
- "requests>=2.32.5",
42
+ "tabulate>=0.9.0",
40
43
  "torch",
41
44
  ]
42
45
  torch = ["torch"]
@@ -53,11 +56,10 @@ kernels = "kernels.cli:main"
53
56
  [tool.setuptools.package-data]
54
57
  kernels = ["python_depends.json"]
55
58
 
56
- [tool.isort]
57
- profile = "black"
58
- line_length = 119
59
-
60
59
  [tool.ruff]
60
+ # If the version is changed, apply the change in the Nix overlay
61
+ # as well.
62
+ required-version = "==0.15.10"
61
63
  exclude = [
62
64
  ".eggs",
63
65
  ".git",
@@ -82,4 +84,6 @@ line-length = 119
82
84
  # Ignored rules:
83
85
  # "E501" -> line length violation
84
86
  lint.ignore = ["E501"]
85
- lint.select = ["E", "F", "W"]
87
+ lint.select = ["E", "F", "I", "W"]
88
+
89
+ [tool.ruff.format]
@@ -2,6 +2,8 @@ import importlib.metadata
2
2
 
3
3
  __version__ = importlib.metadata.version("kernels")
4
4
 
5
+ from kernels._windows import _add_additional_dll_paths
6
+ from kernels.benchmark import Benchmark
5
7
  from kernels.layer import (
6
8
  CUDAProperties,
7
9
  Device,
@@ -21,16 +23,13 @@ from kernels.layer import (
21
23
  )
22
24
  from kernels.utils import (
23
25
  get_kernel,
26
+ get_loaded_kernels,
24
27
  get_local_kernel,
25
28
  get_locked_kernel,
26
29
  has_kernel,
27
30
  install_kernel,
28
31
  load_kernel,
29
32
  )
30
- from kernels.benchmark import Benchmark
31
-
32
-
33
- from kernels._windows import _add_additional_dll_paths
34
33
 
35
34
  _add_additional_dll_paths()
36
35
 
@@ -47,6 +46,7 @@ __all__ = [
47
46
  "LockedLayerRepository",
48
47
  "Mode",
49
48
  "get_kernel",
49
+ "get_loaded_kernels",
50
50
  "get_local_kernel",
51
51
  "get_locked_kernel",
52
52
  "has_kernel",
@@ -0,0 +1,72 @@
1
+ import logging
2
+ import warnings
3
+
4
+ from huggingface_hub.hf_api import GitRefInfo
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
+
9
def _get_available_versions(repo_id: str) -> dict[int, GitRefInfo]:
    """Return the numbered version branches (``v<N>``) of a kernel repository."""
    from kernels.utils import _get_hf_api

    refs = _get_hf_api().list_repo_refs(repo_id=repo_id, repo_type="kernel")

    available: dict[int, GitRefInfo] = {}
    for branch in refs.branches:
        # Only branches shaped like "v<integer>" denote kernel versions.
        if not branch.name.startswith("v"):
            continue
        try:
            available[int(branch.name[1:])] = branch
        except ValueError:
            # "v" followed by something non-numeric — not a version branch.
            pass

    return available
25
+
26
+
27
def resolve_version_spec_as_ref(repo_id: str, version_spec: int) -> GitRefInfo:
    """
    Get the ref for a kernel with the given version.

    Raises:
        ValueError: when the requested version does not exist in the repo.

    Logs a warning when a newer version than the requested one is available.
    """
    versions = _get_available_versions(repo_id)

    ref = versions.get(version_spec, None)
    if ref is None:
        # Fix: sort numerically before stringifying. Sorting the string
        # forms would order e.g. "10" before "2".
        available = ", ".join(str(v) for v in sorted(versions))
        raise ValueError(f"Version {version_spec} not found, available versions: {available}")

    latest_version = max(versions.keys())
    if version_spec < latest_version:
        logger.warning(
            "You are using version %d of '%s', but version %d is available.",
            version_spec,
            repo_id,
            latest_version,
        )

    return ref
49
+
50
+
51
+ def select_revision_or_version(
52
+ repo_id: str,
53
+ *,
54
+ revision: str | None,
55
+ version: int | None,
56
+ ) -> str:
57
+ if revision is not None and version is not None:
58
+ raise ValueError("Only one of `revision` or `version` must be specified.")
59
+
60
+ if revision is not None:
61
+ return revision
62
+ elif version is not None:
63
+ return resolve_version_spec_as_ref(repo_id, version).target_commit
64
+
65
+ warnings.warn(
66
+ "Future versions of `kernels` (>=0.15) will require specifying a kernel version or revision. "
67
+ "See: https://huggingface.co/docs/kernels/migration",
68
+ FutureWarning,
69
+ stacklevel=2,
70
+ )
71
+
72
+ return "main"
@@ -0,0 +1,276 @@
1
+ import ctypes
2
+ import ctypes.util
3
+ import re
4
+ import warnings
5
+ from dataclasses import dataclass
6
+ from typing import ClassVar, Optional, Protocol, runtime_checkable
7
+
8
+ from huggingface_hub.dataclasses import strict
9
+ from packaging.version import Version
10
+
11
+ from kernels.compat import has_torch
12
+
13
+
14
@runtime_checkable
class Backend(Protocol):
    """Structural interface shared by all build-variant backends."""

    @property
    def name(self) -> str:
        """
        Short name of the backend, e.g. "cuda", "rocm", "cpu", etc.
        """
        ...

    @property
    def variant_str(self) -> str:
        """
        The name of the backend as used in a build variant, e.g. `cu128`
        for CUDA 12.8.
        """
        ...
30
+
31
+
32
@dataclass(unsafe_hash=True)
class CANN:
    """Huawei Ascend CANN backend, identified by variants like ``cann82``."""

    _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"cann(\d+)(\d+)")

    version: Version

    @property
    def name(self) -> str:
        return "cann"

    @property
    def variant_str(self) -> str:
        return f"cann{self.version.major}{self.version.minor}"

    @staticmethod
    def parse(s: str) -> "CANN":
        """Parse a ``cann<major><minor>`` variant string."""
        match = CANN._VARIANT_REGEX.fullmatch(s)
        if match is None:
            raise ValueError(f"Invalid CANN variant string: {s!r}")
        major, minor = match.groups()
        return CANN(version=Version(f"{major}.{minor}"))
52
+
53
+
54
@strict
@dataclass(unsafe_hash=True)
class CPU:
    """Plain CPU backend; its build-variant string is always ``cpu``."""

    @property
    def name(self) -> str:
        return "cpu"

    @property
    def variant_str(self) -> str:
        return "cpu"

    @staticmethod
    def parse(s: str) -> "CPU":
        """Parse the ``cpu`` variant string; anything else is rejected."""
        if s == "cpu":
            return CPU()
        raise ValueError(f"Invalid CPU variant string: {s!r}")
70
+
71
+
72
@dataclass(unsafe_hash=True)
class CUDA:
    """NVIDIA CUDA backend, identified by variants like ``cu128`` (CUDA 12.8)."""

    _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"cu(\d+)(\d+)")

    version: Version

    @property
    def name(self) -> str:
        return "cuda"

    @property
    def variant_str(self) -> str:
        return f"cu{self.version.major}{self.version.minor}"

    @staticmethod
    def parse(s: str) -> "CUDA":
        """Parse a ``cu<major><minor>`` variant string."""
        match = CUDA._VARIANT_REGEX.fullmatch(s)
        if match is None:
            raise ValueError(f"Invalid CUDA variant string: {s!r}")
        major, minor = match.groups()
        return CUDA(version=Version(f"{major}.{minor}"))
92
+
93
+
94
@strict
@dataclass(unsafe_hash=True)
class Metal:
    """Apple Metal backend; its build-variant string is always ``metal``."""

    @property
    def name(self) -> str:
        return "metal"

    @property
    def variant_str(self) -> str:
        return "metal"

    @staticmethod
    def parse(s: str) -> "Metal":
        """Parse the ``metal`` variant string; anything else is rejected."""
        if s == "metal":
            return Metal()
        raise ValueError(f"Invalid Metal variant string: {s!r}")
110
+
111
+
112
@strict
@dataclass(unsafe_hash=True)
class Neuron:
    """AWS Neuron backend; its build-variant string is always ``neuron``."""

    @property
    def name(self) -> str:
        return "neuron"

    @property
    def variant_str(self) -> str:
        return "neuron"

    @staticmethod
    def parse(s: str) -> "Neuron":
        """Parse the ``neuron`` variant string; anything else is rejected."""
        if s == "neuron":
            return Neuron()
        raise ValueError(f"Invalid Neuron variant string: {s!r}")
128
+
129
+
130
@dataclass(unsafe_hash=True)
class ROCm:
    """AMD ROCm backend, identified by variants like ``rocm64`` (ROCm 6.4)."""

    _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"rocm(\d+)(\d+)")

    version: Version

    @property
    def name(self) -> str:
        return "rocm"

    @property
    def variant_str(self) -> str:
        return f"rocm{self.version.major}{self.version.minor}"

    @staticmethod
    def parse(s: str) -> "ROCm":
        """Parse a ``rocm<major><minor>`` variant string."""
        match = ROCm._VARIANT_REGEX.fullmatch(s)
        if match is None:
            raise ValueError(f"Invalid ROCm variant string: {s!r}")
        major, minor = match.groups()
        return ROCm(version=Version(f"{major}.{minor}"))
150
+
151
+
152
@dataclass(unsafe_hash=True)
class XPU:
    """Intel XPU backend, identified by variants like ``xpu20251``."""

    _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"xpu(\d+)(\d+)")

    version: Version

    @property
    def name(self) -> str:
        return "xpu"

    @property
    def variant_str(self) -> str:
        return f"xpu{self.version.major}{self.version.minor}"

    @staticmethod
    def parse(s: str) -> "XPU":
        """Parse an ``xpu<major><minor>`` variant string."""
        match = XPU._VARIANT_REGEX.fullmatch(s)
        if match is None:
            raise ValueError(f"Invalid XPU variant string: {s!r}")
        major, minor = match.groups()
        return XPU(version=Version(f"{major}.{minor}"))
172
+
173
+
174
def parse_backend(s: str) -> Backend:
    """Parse a backend variant string (e.g. 'cu128', 'rocm61', 'cpu') into a Backend."""
    # Versionless backends match on the full string.
    exact = {"cpu": CPU, "metal": Metal, "neuron": Neuron}
    cls = exact.get(s)
    if cls is not None:
        return cls.parse(s)

    # Versioned backends match on their prefix; the parser validates the rest.
    for prefix, versioned_cls in (("cu", CUDA), ("rocm", ROCm), ("xpu", XPU), ("cann", CANN)):
        if s.startswith(prefix):
            return versioned_cls.parse(s)

    raise ValueError(f"Unknown backend variant string: {s!r}")
192
+
193
+
194
def _backend() -> Backend:
    """Detect the backend of the running system.

    Prefers the Torch build's compute platform when Torch is installed;
    otherwise probes the CUDA runtime library and falls back to CPU.
    """
    if not has_torch:
        # No Torch: best effort via the CUDA runtime library, else CPU.
        cuda = _get_cuda()
        return cuda if cuda is not None else CPU()

    import torch

    # Neuron must be checked before the specific Torch builds, since the
    # Neuron extension can be loaded into e.g. CUDA Torch builds.
    if hasattr(torch, "neuron"):
        return Neuron()
    if torch.version.cuda is not None:
        return CUDA(version=Version(torch.version.cuda))
    if torch.version.hip is not None:
        # HIP versions can carry a "-<suffix>"; keep only the numeric part.
        return ROCm(version=Version(torch.version.hip.split("-")[0]))
    if torch.backends.mps.is_available():
        return Metal()
    if hasattr(torch.version, "xpu") and torch.version.xpu is not None:
        # torch.version.xpu is a compact digit string; chars [0:4] are taken
        # as the major and [5:6] as the minor — TODO confirm the encoding.
        return XPU(version=Version(f"{torch.version.xpu[0:4]}.{torch.version.xpu[5:6]}"))
    if _get_torch_privateuse_backend_name() == "npu":
        from torch_npu.utils.collect_env import get_cann_version  # type: ignore[import-not-found]

        cann_version = get_cann_version()
        # Characters 0 and 2 of the CANN version string are major and minor.
        return CANN(version=Version(f"{cann_version[0]}.{cann_version[2]}"))
    return CPU()
226
+
227
+
228
+ def _get_torch_privateuse_backend_name() -> str | None:
229
+ import torch
230
+
231
+ if hasattr(torch._C, "_get_privateuse1_backend_name"):
232
+ return torch._C._get_privateuse1_backend_name()
233
+ return None
234
+
235
+
236
def _select_backend(backend: str | None) -> Backend:
    """Map a user-supplied backend name to a Backend, defaulting to the detected one."""
    if backend is None:
        return _backend()

    candidates = _supported_backends()
    selected = candidates.get(backend)
    if selected is None:
        raise ValueError(
            f"Invalid backend '{backend}', system supported backends: {', '.join(sorted(candidates.keys()))}"
        )
    return selected
245
+
246
+
247
def _supported_backends() -> dict[str, Backend]:
    """Backends selectable on this system: always CPU plus the detected backend."""
    detected = _backend()
    supported: dict[str, Backend] = {"cpu": CPU()}
    supported[detected.name] = detected
    return supported
250
+
251
+
252
def _get_cuda() -> Optional[CUDA]:
    """
    Get CUDA runtime library information.

    Returns None when no usable CUDA runtime library is found or the
    runtime version cannot be queried.
    """
    lib_name = ctypes.util.find_library("cudart")
    if lib_name is None:
        return None

    try:
        cudart = ctypes.CDLL(lib_name)
    except OSError:
        return None

    raw_version = ctypes.c_int(0)
    status = cudart.cudaRuntimeGetVersion(ctypes.byref(raw_version))
    if status != 0:
        warnings.warn("System has CUDA runtime library, but cannot get runtime version.")
        return None

    # cudaRuntimeGetVersion encodes the version as (major * 1000 + minor * 10).
    major, remainder = divmod(raw_version.value, 1000)
    minor = remainder // 10
    return CUDA(version=Version(f"{major}.{minor}"))
@@ -0,0 +1,38 @@
1
+ from typing import Any
2
+
3
+
4
+ class Benchmark:
5
+ """Base class for kernel benchmarks.
6
+
7
+ Subclass this to create a benchmark script with automatic timing,
8
+ verification, and reproducibility support. The kernel is loaded
9
+ automatically from the repo_id specified in the CLI command.
10
+
11
+ Example:
12
+ class MyBenchmark(Benchmark):
13
+ seed = 42
14
+
15
+ def setup(self):
16
+ self.x = torch.randn(128, 1024, device=self.device, dtype=torch.float16)
17
+ self.out = torch.empty(128, 512, device=self.device, dtype=torch.float16)
18
+
19
+ def benchmark_silu(self):
20
+ self.kernel.silu_and_mul(self.out, self.x)
21
+
22
+ def verify_silu(self) -> torch.Tensor:
23
+ # Return reference tensor; runner compares with self.out
24
+ return torch.nn.functional.silu(self.x[..., :512]) * self.x[..., 512:]
25
+
26
+ Run with: kernels benchmark <repo_id>
27
+ """
28
+
29
+ seed: int | None = None # Optional: seed for reproducibility
30
+ device: str = "cpu" # Set automatically by runner
31
+
32
+ def __init__(self) -> None:
33
+ self.kernel: Any = None
34
+ self.out: Any = None # Output tensor, set by setup methods
35
+
36
+ def setup(self) -> None:
37
+ """Override to set up tensors as instance attributes."""
38
+ pass
@@ -0,0 +1,85 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from kernels.benchmark import Benchmark
5
+
6
+
7
class SiluAndMulBenchmark(Benchmark):
    """Benchmarks the ``silu_and_mul`` kernel over three workload sizes.

    Each ``setup_*`` allocates a half-split input ``x`` and an output
    buffer ``out``; the matching ``verify_*`` returns the eager PyTorch
    reference (``silu(x[:d]) * x[d:]``) for the runner to compare with
    ``self.out``.
    """

    seed: int = 42

    # Workload: small
    def setup_small(self):
        self.x = torch.randn(8, 1024, 2048, device=self.device, dtype=torch.float16)
        self.out = torch.empty(8, 1024, 1024, device=self.device, dtype=torch.float16)

    def benchmark_small(self):
        self.kernel.silu_and_mul(self.out, self.x)

    def verify_small(self) -> torch.Tensor:
        d = self.x.shape[-1] // 2
        return F.silu(self.x[..., :d]) * self.x[..., d:]

    # Workload: medium
    def setup_medium(self):
        self.x = torch.randn(8, 2048, 4096, device=self.device, dtype=torch.float16)
        self.out = torch.empty(8, 2048, 2048, device=self.device, dtype=torch.float16)

    def benchmark_medium(self):
        self.kernel.silu_and_mul(self.out, self.x)

    def verify_medium(self) -> torch.Tensor:
        d = self.x.shape[-1] // 2
        return F.silu(self.x[..., :d]) * self.x[..., d:]

    # Workload: large
    def setup_large(self):
        self.x = torch.randn(8, 4096, 8192, device=self.device, dtype=torch.float16)
        self.out = torch.empty(8, 4096, 4096, device=self.device, dtype=torch.float16)

    def benchmark_large(self):
        # Fix: the kernel was invoked twice here, so the "large" timing
        # measured two launches unlike every other workload. Invoke once.
        self.kernel.silu_and_mul(self.out, self.x)

    def verify_large(self) -> torch.Tensor:
        d = self.x.shape[-1] // 2
        return F.silu(self.x[..., :d]) * self.x[..., d:]
46
+
47
+
48
class GeluAndMulBenchmark(Benchmark):
    """Benchmarks the ``gelu_and_mul`` kernel over three workload sizes.

    Each ``setup_*`` allocates a half-split input ``x`` and an output
    buffer ``out``; the matching ``verify_*`` returns the eager PyTorch
    reference (``gelu(x[:d]) * x[d:]``) for the runner to compare with
    ``self.out``.
    """

    seed: int = 42

    # Workload: small
    def setup_small(self):
        self.x = torch.randn(8, 1024, 2048, device=self.device, dtype=torch.float16)
        self.out = torch.empty(8, 1024, 1024, device=self.device, dtype=torch.float16)

    def benchmark_small(self):
        self.kernel.gelu_and_mul(self.out, self.x)

    def verify_small(self) -> torch.Tensor:
        half = self.x.shape[-1] // 2
        return F.gelu(self.x[..., :half]) * self.x[..., half:]

    # Workload: medium
    def setup_medium(self):
        self.x = torch.randn(8, 2048, 4096, device=self.device, dtype=torch.float16)
        self.out = torch.empty(8, 2048, 2048, device=self.device, dtype=torch.float16)

    def benchmark_medium(self):
        self.kernel.gelu_and_mul(self.out, self.x)

    def verify_medium(self) -> torch.Tensor:
        half = self.x.shape[-1] // 2
        return F.gelu(self.x[..., :half]) * self.x[..., half:]

    # Workload: large
    def setup_large(self):
        self.x = torch.randn(8, 4096, 8192, device=self.device, dtype=torch.float16)
        self.out = torch.empty(8, 4096, 4096, device=self.device, dtype=torch.float16)

    def benchmark_large(self):
        self.kernel.gelu_and_mul(self.out, self.x)

    def verify_large(self) -> torch.Tensor:
        half = self.x.shape[-1] // 2
        return F.gelu(self.x[..., :half]) * self.x[..., half:]