kernels 0.2.0__tar.gz → 0.2.1__tar.gz

This diff shows the changes between publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the differences between the package versions as they appear in their respective public registries.
@@ -1,7 +1,7 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: kernels
3
- Version: 0.2.0
4
- Summary: Download cuda kernels
3
+ Version: 0.2.1
4
+ Summary: Download compute kernels
5
5
  Author-email: OlivierDehaene <olivier@huggingface.co>, Daniel de Kok <daniel@huggingface.co>, David Holtz <david@huggingface.co>, Nicolas Patry <nicolas@huggingface.co>
6
6
  Requires-Python: >=3.9
7
7
  Description-Content-Type: text/markdown
@@ -1,7 +1,7 @@
1
1
  [project]
2
2
  name = "kernels"
3
- version = "0.2.0"
4
- description = "Download cuda kernels"
3
+ version = "0.2.1"
4
+ description = "Download compute kernels"
5
5
  authors = [
6
6
  { name = "OlivierDehaene", email = "olivier@huggingface.co" },
7
7
  { name = "Daniel de Kok", email = "daniel@huggingface.co" },
@@ -144,9 +144,18 @@ def get_kernel(repo_id: str, revision: str = "main") -> ModuleType:
144
144
  return import_from_path(package_name, package_path / package_name / "__init__.py")
145
145
 
146
146
 
147
- def load_kernel(repo_id: str) -> ModuleType:
148
- """Get a pre-downloaded, locked kernel."""
149
- locked_sha = _get_caller_locked_kernel(repo_id)
147
+ def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
148
+ """
149
+ Get a pre-downloaded, locked kernel.
150
+
151
+ If `lockfile` is not specified, the lockfile will be loaded from the
152
+ caller's package metadata.
153
+ """
154
+ if lockfile is None:
155
+ locked_sha = _get_caller_locked_kernel(repo_id)
156
+ else:
157
+ with open(lockfile, "r") as f:
158
+ locked_sha = _get_locked_kernel(repo_id, f.read())
150
159
 
151
160
  if locked_sha is None:
152
161
  raise ValueError(
@@ -163,6 +172,7 @@ def load_kernel(repo_id: str) -> ModuleType:
163
172
  repo_id,
164
173
  allow_patterns=[f"build/{variant}/*", f"build/{universal_variant}/*"],
165
174
  cache_dir=CACHE_DIR,
175
+ revision=locked_sha,
166
176
  local_files_only=True,
167
177
  )
168
178
  )
@@ -200,11 +210,19 @@ def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleTyp
200
210
  def _get_caller_locked_kernel(repo_id: str) -> Optional[str]:
201
211
  for dist in _get_caller_distributions():
202
212
  lock_json = dist.read_text("kernels.lock")
203
- if lock_json is not None:
204
- for kernel_lock_json in json.loads(lock_json):
205
- kernel_lock = KernelLock.from_json(kernel_lock_json)
206
- if kernel_lock.repo_id == repo_id:
207
- return kernel_lock.sha
213
+ if lock_json is None:
214
+ continue
215
+ locked_sha = _get_locked_kernel(repo_id, lock_json)
216
+ if locked_sha is not None:
217
+ return locked_sha
218
+ return None
219
+
220
+
221
+ def _get_locked_kernel(repo_id: str, lock_json: str) -> Optional[str]:
222
+ for kernel_lock_json in json.loads(lock_json):
223
+ kernel_lock = KernelLock.from_json(kernel_lock_json)
224
+ if kernel_lock.repo_id == repo_id:
225
+ return kernel_lock.sha
208
226
  return None
209
227
 
210
228
 
@@ -1,7 +1,7 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: kernels
3
- Version: 0.2.0
4
- Summary: Download cuda kernels
3
+ Version: 0.2.1
4
+ Summary: Download compute kernels
5
5
  Author-email: OlivierDehaene <olivier@huggingface.co>, Daniel de Kok <daniel@huggingface.co>, David Holtz <david@huggingface.co>, Nicolas Patry <nicolas@huggingface.co>
6
6
  Requires-Python: >=3.9
7
7
  Description-Content-Type: text/markdown
@@ -13,4 +13,4 @@ src/kernels.egg-info/requires.txt
13
13
  src/kernels.egg-info/top_level.txt
14
14
  tests/test_basic.py
15
15
  tests/test_benchmarks.py
16
- tests/test_hash_validation.py
16
+ tests/test_kernel_locking.py
@@ -1,6 +1,7 @@
1
1
  from dataclasses import dataclass
2
2
  from pathlib import Path
3
3
 
4
+ from kernels import load_kernel
4
5
  from kernels.cli import download_kernels
5
6
 
6
7
 
@@ -11,11 +12,13 @@ class DownloadArgs:
11
12
  project_dir: Path
12
13
 
13
14
 
14
- def test_download_hash_validation():
15
- project_dir = Path(__file__).parent / "hash_validation"
16
- download_kernels(DownloadArgs(all_variants=False, project_dir=project_dir))
17
-
18
-
19
15
  def test_download_all_hash_validation():
20
- project_dir = Path(__file__).parent / "hash_validation"
16
+ project_dir = Path(__file__).parent / "kernel_locking"
21
17
  download_kernels(DownloadArgs(all_variants=True, project_dir=project_dir))
18
+
19
+
20
+ def test_load_locked():
21
+ project_dir = Path(__file__).parent / "kernel_locking"
22
+ # Also validates that hashing works correctly.
23
+ download_kernels(DownloadArgs(all_variants=False, project_dir=project_dir))
24
+ load_kernel("kernels-community/activation", lockfile=project_dir / "kernels.lock")
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes