kernels 0.2.0.tar.gz → 0.3.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,7 @@
- Metadata-Version: 2.2
+ Metadata-Version: 2.1
  Name: kernels
- Version: 0.2.0
- Summary: Download cuda kernels
+ Version: 0.3.0
+ Summary: Download compute kernels
  Author-email: OlivierDehaene <olivier@huggingface.co>, Daniel de Kok <daniel@huggingface.co>, David Holtz <david@huggingface.co>, Nicolas Patry <nicolas@huggingface.co>
  Requires-Python: >=3.9
  Description-Content-Type: text/markdown
@@ -57,6 +57,7 @@ the Hub.

  ## 📚 Documentation

+ - [Using layers](docs/layers.md)
  - [Locking kernel versions](docs/locking.md)
  - [Using kernels in a Docker container](docs/docker.md)
  - [Kernel requirements](docs/kernel-requirements.md)
@@ -45,6 +45,7 @@ the Hub.

  ## 📚 Documentation

+ - [Using layers](docs/layers.md)
  - [Locking kernel versions](docs/locking.md)
  - [Using kernels in a Docker container](docs/docker.md)
  - [Kernel requirements](docs/kernel-requirements.md)
@@ -1,7 +1,7 @@
  [project]
  name = "kernels"
- version = "0.2.0"
- description = "Download cuda kernels"
+ version = "0.3.0"
+ description = "Download compute kernels"
  authors = [
      { name = "OlivierDehaene", email = "olivier@huggingface.co" },
      { name = "Daniel de Kok", email = "daniel@huggingface.co" },
@@ -0,0 +1,23 @@
+ from kernels.layer import (
+     Device,
+     LayerRepository,
+     register_kernel_mapping,
+     use_kernel_forward_from_hub,
+ )
+ from kernels.utils import (
+     get_kernel,
+     get_locked_kernel,
+     install_kernel,
+     load_kernel,
+ )
+
+ __all__ = [
+     "get_kernel",
+     "get_locked_kernel",
+     "load_kernel",
+     "install_kernel",
+     "use_kernel_forward_from_hub",
+     "register_kernel_mapping",
+     "LayerRepository",
+     "Device",
+ ]
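The package root now re-exports the layer API alongside the existing kernel loaders, so everything above is importable directly from `kernels`. A minimal sketch of the 0.3.0 import surface; only the imports themselves are guaranteed by this `__init__.py`, and the `gelu_fast` entry point on the activation kernel is an assumption for illustration:

```python
# Sketch of the 0.3.0 public API defined above. `gelu_fast` is a hypothetical
# entry point on the kernels-community/activation repo.
import torch
from kernels import Device, LayerRepository, get_kernel, register_kernel_mapping

activation = get_kernel("kernels-community/activation")  # downloads on first use
x = torch.randn((10, 10), dtype=torch.float16, device="cuda")
y = torch.empty_like(x)
activation.gelu_fast(y, x)  # hypothetical kernel function
```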
@@ -0,0 +1,231 @@
+ import inspect
+ from contextvars import ContextVar
+ from copy import deepcopy
+ from dataclasses import dataclass, field
+ from typing import TYPE_CHECKING, Callable, Dict, Union
+
+ from .utils import get_kernel
+
+ if TYPE_CHECKING:
+     from torch import nn
+
+
+ @dataclass(frozen=True)
+ class Device:
+     type: str
+
+     # In the future we might add compute capabilities, etc.
+
+     def __eq__(self, other):
+         return isinstance(other, Device) and self.type == other.type
+
+     def __hash__(self):
+         return hash(self.type)
+
+
+ @dataclass
+ class LayerRepository:
+     """
+     Repository and name of a layer.
+     """
+
+     layer_name: str = field(
+         metadata={"help": "The name of the layer in the kernel repository."}
+     )
+     repo_id: str = field(metadata={"help": "The kernel hub repository with the layer."})
+     revision: str = field(
+         default="main", metadata={"help": "The revision of the layer."}
+     )
+
+     def __eq__(self, other):
+         return (
+             isinstance(other, LayerRepository)
+             and self.layer_name == other.layer_name
+             and self.repo_id == other.repo_id
+             and self.revision == other.revision
+         )
+
+     def __hash__(self):
+         return hash((self.layer_name, self.repo_id, self.revision))
+
+
+ _KERNEL_MAPPING: ContextVar[Dict[str, Dict[Device, LayerRepository]]] = ContextVar(
+     "_KERNEL_MAPPING", default={}
+ )
+
+
+ def use_kernel_mapping(mapping: Dict[str, Dict[Union[Device, str], LayerRepository]]):
+     class ContextManager:
+         def __enter__(self):
+             # Mappings always stack on previous mappings.
+             self.token = _KERNEL_MAPPING.set(deepcopy(_KERNEL_MAPPING.get()))
+             register_kernel_mapping(mapping)
+
+         def __exit__(self, exc_type, exc_value, traceback):
+             _KERNEL_MAPPING.reset(self.token)
+
+     return ContextManager()
+
+
+ def register_kernel_mapping(
+     mapping: Dict[str, Dict[Union[Device, str], LayerRepository]]
+ ):
+     """
+     Allows one to register a mapping between a layer name and the corresponding kernel to use, depending on the device.
+     This should be used in conjunction with the `use_kernel_forward_from_hub` decorator on the class.
+     Example usage:
+
+     ```python
+     from kernels import LayerRepository, register_kernel_mapping
+
+     kernel_layer_mapping = {
+         "LlamaRMSNorm": {
+             "cuda": LayerRepository(
+                 repo_id="kernels-community/activation",
+                 layer_name="RmsNorm",
+                 revision="layers",
+             ),
+         },
+     }
+     register_kernel_mapping(kernel_layer_mapping)
+     ```
+     """
+     # Merge with existing mappings.
+     for new_kernel, new_device_repos in mapping.items():
+         device_repo = _KERNEL_MAPPING.get().setdefault(new_kernel, {})
+         for new_device, new_repo in new_device_repos.items():
+             if isinstance(new_device, str):
+                 device_repo[Device(type=new_device)] = new_repo
+             else:
+                 device_repo[new_device] = new_repo
+
+
+ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool = True):
+     """
+     Replace the forward function of a layer using a layer from the kernel hub.
+     This function monkeypatches a layer, replacing the `forward` method
+     of the layer with that of a layer from the hub. The replacement is done
+     when a layer matching `layer_name` and device type is registered through
+     `register_kernel_mapping`. The device type is inferred from the first
+     argument to `forward`.
+     """
+
+     fallback_forward = cls.forward
+
+     cached_forward: Dict[LayerRepository, Callable] = {}
+
+     def forward(self, x, **args):
+         kernel = _KERNEL_MAPPING.get().get(layer_name)
+         if kernel is None:
+             if not use_fallback:
+                 raise ValueError(f"No layer mapping for `{layer_name}`")
+             return fallback_forward(self, x, **args)
+
+         device = getattr(x, "device", None)
+         if device is None:
+             return fallback_forward(self, x, **args)
+
+         repo = kernel.get(Device(type=device.type))
+         if repo is None:
+             if not use_fallback:
+                 raise ValueError(
+                     f"No layer mapping for `{layer_name}` with device type `{device.type}`"
+                 )
+             return fallback_forward(self, x, **args)
+
+         # Short-circuit if we already loaded the layer.
+         layer_forward = cached_forward.get(repo, None)
+         if layer_forward is not None:
+             return layer_forward(self, x, **args)
+
+         layer = _get_kernel_layer(
+             repo_id=repo.repo_id,
+             layer_name=repo.layer_name,
+             revision=repo.revision,
+         )
+
+         # We have to validate against the original signature.
+         orig_forward = cls.forward
+         try:
+             cls.forward = fallback_forward
+             _validate_layer(check_cls=cls, cls=layer)
+         finally:
+             cls.forward = orig_forward
+
+         layer_forward = layer.forward
+         cached_forward[repo] = layer_forward
+
+         return layer_forward(self, x, **args)
+
+     cls.forward = forward
+
+
+ def use_kernel_forward_from_hub(layer_name: str, *, use_fallback: bool = True):
+     """
+     Replace the forward function of a layer using a layer from the kernel hub.
+     This decorator can be applied to a layer and replaces the forward method
+     of the layer with that of a layer from the hub. The replacement is done
+     when a layer matching `layer_name` and device type is registered through
+     `register_kernel_mapping`. The device type is inferred from the first
+     argument to `forward`.
+     """
+
+     def decorator(cls):
+         replace_kernel_forward_from_hub(cls, layer_name, use_fallback=use_fallback)
+         return cls
+
+     return decorator
+
+
+ def _get_kernel_layer(*, repo_id: str, layer_name: str, revision: str) -> "nn.Module":
+     """Get a layer from a kernel."""
+
+     kernel = get_kernel(repo_id, revision=revision)
+
+     if getattr(kernel, "layers", None) is None:
+         raise ValueError(
+             f"Kernel `{repo_id}` at revision `{revision}` does not define any layers."
+         )
+
+     layer = getattr(kernel.layers, layer_name, None)
+     if layer is None:
+         raise ValueError(f"Layer `{layer_name}` not found in kernel `{repo_id}`.")
+     return layer
+
+
+ def _validate_layer(*, check_cls, cls):
+     # The layer must at least have the following properties: (1) it
+     # must be stateless; (2) the forward signature should correspond to
+     # the signature it is replacing; (3) forward should not call other
+     # methods.
+
+     from torch import nn
+
+     if not issubclass(cls, nn.Module):
+         raise TypeError(f"Layer `{cls}` is not a Torch layer.")
+
+     # We verify statelessness by checking that the class does not have its own
+     # constructor (since the constructor could add member variables)...
+     if cls.__init__ is not nn.Module.__init__:
+         raise TypeError("Layer must not override nn.Module constructor.")
+
+     # ... or predefined member variables.
+     torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
+     cls_members = {name for name, _ in inspect.getmembers(cls)}
+     if cls_members - torch_module_members != set():
+         raise TypeError("Layer must not contain additional members.")
+
+     # Check whether the forward signatures are similar.
+     params = inspect.signature(cls.forward).parameters
+     ref_params = inspect.signature(check_cls.forward).parameters
+
+     if len(params) != len(ref_params):
+         raise TypeError(
+             "Forward signature does not match: different number of arguments."
+         )
+
+     for param, ref_param in zip(params.values(), ref_params.values()):
+         if param.kind != ref_param.kind:
+             raise TypeError(
+                 f"Forward signature does not match: different kind of arguments ({param} ({param.kind}) and {ref_param} ({ref_param.kind}))"
+             )
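The `use_kernel_mapping` context manager added above stacks a temporary mapping on top of whatever was registered and restores the previous state on exit. A minimal sketch of how it composes with the decorator, reusing the `kernels-community/activation` layer from the examples elsewhere in this diff (note that `use_kernel_mapping` is not re-exported from the package root, so it is imported from `kernels.layer`, as the new tests also do):

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

from kernels import LayerRepository, use_kernel_forward_from_hub
from kernels.layer import use_kernel_mapping


@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMul(nn.Module):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        d = input.shape[-1] // 2
        return F.silu(input[..., :d]) * input[..., d:]


mapping = {
    "SiluAndMul": {
        "cuda": LayerRepository(
            repo_id="kernels-community/activation",
            layer_name="SiluAndMul",
            revision="layers",
        )
    }
}

# Inside the context, CUDA tensors are routed to the hub layer; outside it,
# the class's original forward serves as the fallback.
with use_kernel_mapping(mapping):
    y = SiluAndMul()(torch.randn(32, 64, device="cuda"))
```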
@@ -144,9 +144,18 @@ def get_kernel(repo_id: str, revision: str = "main") -> ModuleType:
      return import_from_path(package_name, package_path / package_name / "__init__.py")


- def load_kernel(repo_id: str) -> ModuleType:
-     """Get a pre-downloaded, locked kernel."""
-     locked_sha = _get_caller_locked_kernel(repo_id)
+ def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
+     """
+     Get a pre-downloaded, locked kernel.
+
+     If `lockfile` is not specified, the lockfile will be loaded from the
+     caller's package metadata.
+     """
+     if lockfile is None:
+         locked_sha = _get_caller_locked_kernel(repo_id)
+     else:
+         with open(lockfile, "r") as f:
+             locked_sha = _get_locked_kernel(repo_id, f.read())

      if locked_sha is None:
          raise ValueError(
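The new keyword argument makes lockfile-pinned loading usable outside installed packages. A short sketch mirroring the `test_load_locked` test further down in this diff; the kernel must already be present in the local cache, since `load_kernel` only loads pre-downloaded kernels:

```python
from pathlib import Path

from kernels import load_kernel

# Load a kernel pinned by an explicit lockfile instead of resolving the
# lock from the caller's package metadata.
activation = load_kernel(
    "kernels-community/activation",
    lockfile=Path("kernels.lock"),
)
```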
@@ -163,6 +172,7 @@ def load_kernel(repo_id: str) -> ModuleType:
          repo_id,
          allow_patterns=[f"build/{variant}/*", f"build/{universal_variant}/*"],
          cache_dir=CACHE_DIR,
+         revision=locked_sha,
          local_files_only=True,
      )
  )
@@ -200,11 +210,19 @@ def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleType:
  def _get_caller_locked_kernel(repo_id: str) -> Optional[str]:
      for dist in _get_caller_distributions():
          lock_json = dist.read_text("kernels.lock")
-         if lock_json is not None:
-             for kernel_lock_json in json.loads(lock_json):
-                 kernel_lock = KernelLock.from_json(kernel_lock_json)
-                 if kernel_lock.repo_id == repo_id:
-                     return kernel_lock.sha
+         if lock_json is None:
+             continue
+         locked_sha = _get_locked_kernel(repo_id, lock_json)
+         if locked_sha is not None:
+             return locked_sha
+     return None
+
+
+ def _get_locked_kernel(repo_id: str, lock_json: str) -> Optional[str]:
+     for kernel_lock_json in json.loads(lock_json):
+         kernel_lock = KernelLock.from_json(kernel_lock_json)
+         if kernel_lock.repo_id == repo_id:
+             return kernel_lock.sha
      return None

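The refactor splits lockfile parsing out of the metadata lookup so that both the installed-package path and the new explicit-lockfile path share `_get_locked_kernel`. A sketch of the lockfile shape it expects, inferred only from the fields accessed above (`repo_id` and `sha`); real entries produced by the lock tooling likely carry more fields, such as per-file hashes:

```python
import json

# Hypothetical kernels.lock content; only repo_id and sha are read by
# _get_locked_kernel, via KernelLock.from_json.
lock_json = json.dumps(
    [{"repo_id": "kernels-community/activation", "sha": "<locked commit sha>"}]
)
```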
@@ -1,7 +1,7 @@
- Metadata-Version: 2.2
+ Metadata-Version: 2.1
  Name: kernels
- Version: 0.2.0
- Summary: Download cuda kernels
+ Version: 0.3.0
+ Summary: Download compute kernels
  Author-email: OlivierDehaene <olivier@huggingface.co>, Daniel de Kok <daniel@huggingface.co>, David Holtz <david@huggingface.co>, Nicolas Patry <nicolas@huggingface.co>
  Requires-Python: >=3.9
  Description-Content-Type: text/markdown
@@ -57,6 +57,7 @@ the Hub.

  ## 📚 Documentation

+ - [Using layers](docs/layers.md)
  - [Locking kernel versions](docs/locking.md)
  - [Using kernels in a Docker container](docs/docker.md)
  - [Kernel requirements](docs/kernel-requirements.md)
@@ -3,6 +3,7 @@ pyproject.toml
  src/kernels/__init__.py
  src/kernels/cli.py
  src/kernels/compat.py
+ src/kernels/layer.py
  src/kernels/lockfile.py
  src/kernels/utils.py
  src/kernels.egg-info/PKG-INFO
@@ -13,4 +14,5 @@ src/kernels.egg-info/requires.txt
  src/kernels.egg-info/top_level.txt
  tests/test_basic.py
  tests/test_benchmarks.py
- tests/test_hash_validation.py
+ tests/test_kernel_locking.py
+ tests/test_layer.py
@@ -0,0 +1 @@
+ kernels
@@ -1,6 +1,7 @@
  from dataclasses import dataclass
  from pathlib import Path

+ from kernels import load_kernel
  from kernels.cli import download_kernels


@@ -11,11 +12,13 @@ class DownloadArgs:
      project_dir: Path


- def test_download_hash_validation():
-     project_dir = Path(__file__).parent / "hash_validation"
-     download_kernels(DownloadArgs(all_variants=False, project_dir=project_dir))
-
-
  def test_download_all_hash_validation():
-     project_dir = Path(__file__).parent / "hash_validation"
+     project_dir = Path(__file__).parent / "kernel_locking"
      download_kernels(DownloadArgs(all_variants=True, project_dir=project_dir))
+
+
+ def test_load_locked():
+     project_dir = Path(__file__).parent / "kernel_locking"
+     # Also validates that hashing works correctly.
+     download_kernels(DownloadArgs(all_variants=False, project_dir=project_dir))
+     load_kernel("kernels-community/activation", lockfile=project_dir / "kernels.lock")
@@ -0,0 +1,168 @@
+ import pytest
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+
+ from kernels import (
+     Device,
+     LayerRepository,
+     register_kernel_mapping,
+     use_kernel_forward_from_hub,
+ )
+ from kernels.layer import _KERNEL_MAPPING, _validate_layer, use_kernel_mapping
+
+ kernel_layer_mapping = {
+     "SiluAndMul": {
+         Device(type="cuda"): LayerRepository(
+             repo_id="kernels-community/activation",
+             layer_name="SiluAndMul",
+             revision="layers",
+         )
+     },
+     "SiluAndMulStringDevice": {
+         "cuda": LayerRepository(
+             repo_id="kernels-community/activation",
+             layer_name="SiluAndMul",
+             revision="layers",
+         )
+     },
+ }
+
+ register_kernel_mapping(kernel_layer_mapping)
+
+
+ class SiluAndMul(nn.Module):
+     def __init__(self):
+         super().__init__()
+         # Used to check that we called hub kernel.
+         self.n_calls = 0
+
+     def forward(self, input: torch.Tensor) -> torch.Tensor:
+         self.n_calls += 1
+         d = input.shape[-1] // 2
+         return F.silu(input[..., :d]) * input[..., d:]
+
+
+ @use_kernel_forward_from_hub("SiluAndMul")
+ class SiluAndMulWithKernel(SiluAndMul):
+     pass
+
+
+ @use_kernel_forward_from_hub("SiluAndMulStringDevice")
+ class SiluAndMulStringDevice(SiluAndMul):
+     pass
+
+
+ @pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice])
+ @pytest.mark.parametrize("device", ["cuda", "cpu"])
+ def test_hub_forward(cls, device):
+     torch.random.manual_seed(0)
+
+     silu_and_mul = SiluAndMul()
+     X = torch.randn((32, 64), device=device)
+     Y = silu_and_mul(X)
+
+     silu_and_mul_with_kernel = cls()
+     Y_kernel = silu_and_mul_with_kernel(X)
+
+     torch.testing.assert_close(Y_kernel, Y)
+
+     assert silu_and_mul.n_calls == 1
+     if device == "cuda":
+         assert silu_and_mul_with_kernel.n_calls == 0
+     else:
+         assert silu_and_mul_with_kernel.n_calls == 1
+
+
+ def test_layer_fallback_works():
+     @use_kernel_forward_from_hub("SiluAndMulNonExisting")
+     class SiluAndMulWithKernelFallback(SiluAndMul):
+         pass
+
+     # Check that we don't raise an exception for a non-existing kernel.
+     SiluAndMulWithKernelFallback()
+
+
+ def test_mapping_contexts():
+     assert set(_KERNEL_MAPPING.get().keys()) == {"SiluAndMul", "SiluAndMulStringDevice"}
+
+     extra_mapping1 = {
+         "TestKernel": {
+             Device(type="cuda"): LayerRepository(
+                 repo_id="kernels-community/activation",
+                 layer_name="SiluAndMul",
+                 revision="layers",
+             )
+         }
+     }
+
+     with use_kernel_mapping(extra_mapping1):
+         assert set(_KERNEL_MAPPING.get().keys()) == {
+             "SiluAndMul",
+             "SiluAndMulStringDevice",
+             "TestKernel",
+         }
+
+         extra_mapping2 = {
+             "SiluAndMul": {
+                 Device(type="cuda"): LayerRepository(
+                     repo_id="kernels-community/non-existing",
+                     layer_name="SiluAndMul",
+                     revision="layers",
+                 )
+             }
+         }
+
+         with use_kernel_mapping(extra_mapping2):
+             assert set(_KERNEL_MAPPING.get().keys()) == {
+                 "SiluAndMul",
+                 "SiluAndMulStringDevice",
+                 "TestKernel",
+             }
+             assert (
+                 _KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
+                 == "kernels-community/non-existing"
+             )
+
+         assert set(_KERNEL_MAPPING.get().keys()) == {
+             "SiluAndMul",
+             "SiluAndMulStringDevice",
+             "TestKernel",
+         }
+         assert (
+             _KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
+             == "kernels-community/activation"
+         )
+
+     assert set(_KERNEL_MAPPING.get().keys()) == {
+         "SiluAndMul",
+         "SiluAndMulStringDevice",
+     }
+
+
+ def test_validate_kernel_layer():
+     class BadLayer(nn.Module):
+         def __init__(self, *args, **kwargs):
+             super().__init__(*args, **kwargs)
+             self.foo = 42
+
+     with pytest.raises(TypeError, match="not override"):
+         _validate_layer(cls=BadLayer, check_cls=SiluAndMul)
+
+     class BadLayer2(nn.Module):
+         foo: int = 42
+
+     with pytest.raises(TypeError, match="not contain additional members"):
+         _validate_layer(cls=BadLayer2, check_cls=SiluAndMul)
+
+     class BadLayer3(nn.Module):
+         def forward(self, x: torch.Tensor, foo: int) -> torch.Tensor: ...
+
+     with pytest.raises(TypeError, match="different number of arguments"):
+         _validate_layer(cls=BadLayer3, check_cls=SiluAndMul)
+
+     class BadLayer4(nn.Module):
+         def forward(self, *, x: torch.Tensor) -> torch.Tensor: ...
+
+     with pytest.raises(TypeError, match="different kind of arguments"):
+         _validate_layer(cls=BadLayer4, check_cls=SiluAndMul)
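The `test_validate_kernel_layer` cases above spell out the contract a hub layer must satisfy: subclass `nn.Module`, no custom constructor, no extra members, and a `forward` signature matching the layer being replaced. A minimal layer that passes all four checks, modeled on the `SiluAndMul` reference used in these tests (the class name is illustrative):

```python
import torch
import torch.nn as nn
from torch.nn import functional as F


class GoodSiluAndMul(nn.Module):
    # No __init__ and no extra members: the layer stays stateless, so its
    # forward can safely replace the forward of an existing module instance.
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        d = input.shape[-1] // 2
        return F.silu(input[..., :d]) * input[..., d:]
```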
@@ -1,3 +0,0 @@
- from kernels.utils import get_kernel, get_locked_kernel, install_kernel, load_kernel
-
- __all__ = ["get_kernel", "get_locked_kernel", "load_kernel", "install_kernel"]
@@ -1,2 +0,0 @@
- hf_kernels
- kernels