kernels 0.5.0.dev0__tar.gz → 0.5.1__tar.gz

This diff shows the changes between two publicly available versions of a package, each of which has been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: kernels
3
- Version: 0.5.0.dev0
3
+ Version: 0.5.1
4
4
  Summary: Download compute kernels
5
5
  Author-email: OlivierDehaene <olivier@huggingface.co>, Daniel de Kok <daniel@huggingface.co>, David Holtz <david@huggingface.co>, Nicolas Patry <nicolas@huggingface.co>
6
6
  License: Apache-2.0
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "kernels"
3
- version = "0.5.0.dev0"
3
+ version = "0.5.1"
4
4
  description = "Download compute kernels"
5
5
  authors = [
6
6
  { name = "OlivierDehaene", email = "olivier@huggingface.co" },
@@ -54,11 +54,29 @@ def build_variant() -> str:
54
54
  return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-{compute_framework}-{cpu}-{os}"
55
55
 
56
56
 
57
- def universal_build_variant() -> str:
57
+ def build_variant_noarch() -> str:
58
+ import torch
59
+
60
+ if torch.version.cuda is not None:
61
+ return "torch-cuda"
62
+ elif torch.version.hip is not None:
63
+ return "torch-rocm"
64
+ elif torch.backends.mps.is_available():
65
+ return "torch-metal"
66
+ else:
67
+ return "torch-cpu"
68
+
69
+
70
+ def build_variant_universal() -> str:
58
71
  # Once we support other frameworks, detection goes here.
59
72
  return "torch-universal"
60
73
 
61
74
 
75
+ def build_variants() -> List[str]:
76
+ """Return compatible build variants in preferred order."""
77
+ return [build_variant(), build_variant_noarch(), build_variant_universal()]
78
+
79
+
62
80
  def import_from_path(module_name: str, file_path: Path) -> ModuleType:
63
81
  # We cannot use the module name as-is, after adding it to `sys.modules`,
64
82
  # it would also be used for other imports. So, we make a module name that
@@ -89,25 +107,32 @@ def install_kernel(
89
107
  The output path is validated against `hash` when set.
90
108
  """
91
109
  package_name = package_name_from_repo_id(repo_id)
92
- variant = build_variant()
93
- universal_variant = universal_build_variant()
110
+ allow_patterns = [f"build/{variant}/*" for variant in build_variants()]
94
111
  repo_path = Path(
95
112
  snapshot_download(
96
113
  repo_id,
97
- allow_patterns=[f"build/{variant}/*", f"build/{universal_variant}/*"],
114
+ allow_patterns=allow_patterns,
98
115
  cache_dir=CACHE_DIR,
99
116
  revision=revision,
100
117
  local_files_only=local_files_only,
101
118
  )
102
119
  )
103
120
 
104
- variant_path = repo_path / "build" / variant
105
- universal_variant_path = repo_path / "build" / universal_variant
121
+ variants = build_variants()
122
+ variant = None
123
+ variant_path = None
124
+ for candidate_variant in variants:
125
+ variant_path = repo_path / "build" / candidate_variant
126
+ if variant_path.exists():
127
+ variant = candidate_variant
128
+ break
106
129
 
107
- if not variant_path.exists() and universal_variant_path.exists():
108
- # Fall back to universal variant.
109
- variant = universal_variant
110
- variant_path = universal_variant_path
130
+ if variant is None:
131
+ raise FileNotFoundError(
132
+ f"Kernel at path `{repo_path}` does not have one of build variants: {', '.join(variants)}"
133
+ )
134
+
135
+ assert variant_path is not None
111
136
 
112
137
  if variant_locks is not None:
113
138
  variant_lock = variant_locks.get(variant)
@@ -167,21 +192,16 @@ def has_kernel(repo_id: str, revision: str = "main") -> bool:
167
192
  (Torch version and compute framework).
168
193
  """
169
194
  package_name = package_name_from_repo_id(repo_id)
170
- variant = build_variant()
171
- universal_variant = universal_build_variant()
172
-
173
- if file_exists(
174
- repo_id,
175
- revision=revision,
176
- filename=f"build/{universal_variant}/{package_name}/__init__.py",
177
- ):
178
- return True
179
-
180
- return file_exists(
181
- repo_id,
182
- revision=revision,
183
- filename=f"build/{variant}/{package_name}/__init__.py",
184
- )
195
+
196
+ for variant in build_variants():
197
+ if file_exists(
198
+ repo_id,
199
+ revision=revision,
200
+ filename=f"build/{variant}/{package_name}/__init__.py",
201
+ ):
202
+ return True
203
+
204
+ return False
185
205
 
186
206
 
187
207
  def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
@@ -204,33 +224,29 @@ def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
204
224
 
205
225
  package_name = package_name_from_repo_id(repo_id)
206
226
 
207
- variant = build_variant()
208
- universal_variant = universal_build_variant()
209
-
227
+ allow_patterns = [f"build/{variant}/*" for variant in build_variants()]
210
228
  repo_path = Path(
211
229
  snapshot_download(
212
230
  repo_id,
213
- allow_patterns=[f"build/{variant}/*", f"build/{universal_variant}/*"],
231
+ allow_patterns=allow_patterns,
214
232
  cache_dir=CACHE_DIR,
215
233
  revision=locked_sha,
216
234
  local_files_only=True,
217
235
  )
218
236
  )
219
237
 
220
- variant_path = repo_path / "build" / variant
221
- universal_variant_path = repo_path / "build" / universal_variant
222
- if not variant_path.exists() and universal_variant_path.exists():
223
- # Fall back to universal variant.
224
- variant = universal_variant
225
- variant_path = universal_variant_path
226
-
227
- module_init_path = variant_path / package_name / "__init__.py"
228
- if not os.path.exists(module_init_path):
229
- raise FileNotFoundError(
230
- f"Locked kernel `{repo_id}` does not have build `{variant}` or was not downloaded with `kernels download <project>`"
231
- )
238
+ for variant in build_variants():
239
+ variant_path = repo_path / "build" / variant
240
+ module_init_path = variant_path / package_name / "__init__.py"
241
+ if module_init_path.exists():
242
+ module_init_path = variant_path / package_name / "__init__.py"
243
+ return import_from_path(
244
+ package_name, variant_path / package_name / "__init__.py"
245
+ )
232
246
 
233
- return import_from_path(package_name, variant_path / package_name / "__init__.py")
247
+ raise FileNotFoundError(
248
+ f"Locked kernel `{repo_id}` does not have applicable variant or was not downloaded with `kernels download <project>`"
249
+ )
234
250
 
235
251
 
236
252
  def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleType:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: kernels
3
- Version: 0.5.0.dev0
3
+ Version: 0.5.1
4
4
  Summary: Download compute kernels
5
5
  Author-email: OlivierDehaene <olivier@huggingface.co>, Daniel de Kok <daniel@huggingface.co>, David Holtz <david@huggingface.co>, Nicolas Patry <nicolas@huggingface.co>
6
6
  License: Apache-2.0
@@ -6,7 +6,7 @@ from kernels import get_kernel, has_kernel
6
6
 
7
7
  @pytest.fixture
8
8
  def kernel():
9
- return get_kernel("kernels-community/activation")
9
+ return get_kernel("kernels-community/activation", revision="v0.0.3")
10
10
 
11
11
 
12
12
  @pytest.fixture
@@ -39,7 +39,7 @@ def test_gelu_fast(kernel, device):
39
39
  @pytest.mark.parametrize(
40
40
  "kernel_exists",
41
41
  [
42
- ("kernels-community/activation", "main", True),
42
+ ("kernels-community/activation", "v0.0.3", True),
43
43
  ("kernels-community/triton-layer-norm", "main", True),
44
44
  # Repo only contains Torch 2.4 kernels (and we don't
45
45
  # support/test against this version).
@@ -64,3 +64,10 @@ def test_universal_kernel(universal_kernel):
64
64
  out_check = out_check.to(torch.float16)
65
65
 
66
66
  torch.testing.assert_close(out, out_check, rtol=1e-1, atol=1e-1)
67
+
68
+
69
+ def test_noarch_kernel(device):
70
+ supported_devices = ["cpu", "cuda"]
71
+ if device not in supported_devices:
72
+ pytest.skip(f"Device is not one of: {','.join(supported_devices)}")
73
+ get_kernel("kernels-test/silu-and-mul-noarch")
@@ -6,7 +6,7 @@ from kernels import get_kernel
6
6
 
7
7
  @pytest.fixture
8
8
  def kernel():
9
- return get_kernel("kernels-community/activation")
9
+ return get_kernel("kernels-community/activation", revision="v0.0.3")
10
10
 
11
11
 
12
12
  @pytest.fixture
File without changes
File without changes
File without changes
File without changes