kernels 0.2.0.tar.gz → 0.3.0.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {kernels-0.2.0 → kernels-0.3.0}/PKG-INFO +4 -3
- {kernels-0.2.0 → kernels-0.3.0}/README.md +1 -0
- {kernels-0.2.0 → kernels-0.3.0}/pyproject.toml +2 -2
- kernels-0.3.0/src/kernels/__init__.py +23 -0
- kernels-0.3.0/src/kernels/layer.py +231 -0
- {kernels-0.2.0 → kernels-0.3.0}/src/kernels/utils.py +26 -8
- {kernels-0.2.0 → kernels-0.3.0}/src/kernels.egg-info/PKG-INFO +4 -3
- {kernels-0.2.0 → kernels-0.3.0}/src/kernels.egg-info/SOURCES.txt +3 -1
- kernels-0.3.0/src/kernels.egg-info/top_level.txt +1 -0
- kernels-0.2.0/tests/test_hash_validation.py → kernels-0.3.0/tests/test_kernel_locking.py +9 -6
- kernels-0.3.0/tests/test_layer.py +168 -0
- kernels-0.2.0/src/kernels/__init__.py +0 -3
- kernels-0.2.0/src/kernels.egg-info/top_level.txt +0 -2
- {kernels-0.2.0 → kernels-0.3.0}/setup.cfg +0 -0
- {kernels-0.2.0 → kernels-0.3.0}/src/kernels/cli.py +0 -0
- {kernels-0.2.0 → kernels-0.3.0}/src/kernels/compat.py +0 -0
- {kernels-0.2.0 → kernels-0.3.0}/src/kernels/lockfile.py +0 -0
- {kernels-0.2.0 → kernels-0.3.0}/src/kernels.egg-info/dependency_links.txt +0 -0
- {kernels-0.2.0 → kernels-0.3.0}/src/kernels.egg-info/entry_points.txt +0 -0
- {kernels-0.2.0 → kernels-0.3.0}/src/kernels.egg-info/requires.txt +0 -0
- {kernels-0.2.0 → kernels-0.3.0}/tests/test_basic.py +0 -0
- {kernels-0.2.0 → kernels-0.3.0}/tests/test_benchmarks.py +0 -0
**{kernels-0.2.0 → kernels-0.3.0}/PKG-INFO**

```diff
@@ -1,7 +1,7 @@
-Metadata-Version: 2.
+Metadata-Version: 2.1
 Name: kernels
-Version: 0.2.0
-Summary: Download
+Version: 0.3.0
+Summary: Download compute kernels
 Author-email: OlivierDehaene <olivier@huggingface.co>, Daniel de Kok <daniel@huggingface.co>, David Holtz <david@huggingface.co>, Nicolas Patry <nicolas@huggingface.co>
 Requires-Python: >=3.9
 Description-Content-Type: text/markdown
@@ -57,6 +57,7 @@ the Hub.
 
 ## 📚 Documentation
 
+- [Using layers](docs/layers.md)
 - [Locking kernel versions](docs/locking.md)
 - [Using kernels in a Docker container](docs/docker.md)
 - [Kernel requirements](docs/kernel-requirements.md)
```
**{kernels-0.2.0 → kernels-0.3.0}/pyproject.toml**

```diff
@@ -1,7 +1,7 @@
 [project]
 name = "kernels"
-version = "0.2.0"
-description = "Download
+version = "0.3.0"
+description = "Download compute kernels"
 authors = [
     { name = "OlivierDehaene", email = "olivier@huggingface.co" },
     { name = "Daniel de Kok", email = "daniel@huggingface.co" },
```
**kernels-0.3.0/src/kernels/__init__.py (new)**

```diff
@@ -0,0 +1,23 @@
+from kernels.layer import (
+    Device,
+    LayerRepository,
+    register_kernel_mapping,
+    use_kernel_forward_from_hub,
+)
+from kernels.utils import (
+    get_kernel,
+    get_locked_kernel,
+    install_kernel,
+    load_kernel,
+)
+
+__all__ = [
+    "get_kernel",
+    "get_locked_kernel",
+    "load_kernel",
+    "install_kernel",
+    "use_kernel_forward_from_hub",
+    "register_kernel_mapping",
+    "LayerRepository",
+    "Device",
+]
```
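For orientation, everything in the `__all__` list above is importable from the package root in 0.3.0. A minimal sketch; the repository name comes from this release's own tests, and it assumes the repository ships a build for your platform:

```python
from kernels import get_kernel

# Downloads the kernel (or reuses the local cache) and imports it as a module.
activation = get_kernel("kernels-community/activation")
```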
**kernels-0.3.0/src/kernels/layer.py (new)**

```diff
@@ -0,0 +1,231 @@
+import inspect
+from contextvars import ContextVar
+from copy import deepcopy
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Callable, Dict, Union
+
+from .utils import get_kernel
+
+if TYPE_CHECKING:
+    from torch import nn
+
+
+@dataclass(frozen=True)
+class Device:
+    type: str
+
+    # In the future we might add compute capabilities, etc.
+
+    def __eq__(self, other):
+        return isinstance(other, Device) and self.type == other.type
+
+    def __hash__(self):
+        return hash(self.type)
+
+
+@dataclass
+class LayerRepository:
+    """
+    Repository and name of a layer.
+    """
+
+    layer_name: str = field(
+        metadata={"help": "The name of the layer in the kernel repository."}
+    )
+    repo_id: str = field(metadata={"help": "The kernel hub repository with the layer."})
+    revision: str = field(
+        default="main", metadata={"help": "The revision of the layer."}
+    )
+
+    def __eq__(self, other):
+        return (
+            isinstance(other, LayerRepository)
+            and self.layer_name == other.layer_name
+            and self.repo_id == other.repo_id
+            and self.revision == other.revision
+        )
+
+    def __hash__(self):
+        return hash((self.layer_name, self.repo_id, self.revision))
+
+
+_KERNEL_MAPPING: ContextVar[Dict[str, Dict[Device, LayerRepository]]] = ContextVar(
+    "_KERNEL_MAPPING", default={}
+)
+
+
+def use_kernel_mapping(mapping: Dict[str, Dict[Union[Device, str], LayerRepository]]):
+    class ContextManager:
+        def __enter__(self):
+            # Mappings always stack on previous mappings.
+            self.token = _KERNEL_MAPPING.set(deepcopy(_KERNEL_MAPPING.get()))
+            register_kernel_mapping(mapping)
+
+        def __exit__(self, exc_type, exc_value, traceback):
+            _KERNEL_MAPPING.reset(self.token)
+
+    return ContextManager()
+
+
+def register_kernel_mapping(
+    mapping: Dict[str, Dict[Union[Device, str], LayerRepository]]
+):
+    """
+    Allows one to register a mapping between a layer name and the corresponding kernel to use, depending on the device.
+    This should be used in conjunction with the `use_kernel_forward_from_hub` decorator on the class.
+    Example usage:
+
+    ```python
+    from kernels import LayerRepository, register_kernel_mapping
+
+    kernel_layer_mapping = {
+        "LlamaRMSNorm": {
+            "cuda": LayerRepository(
+                repo_id="kernels-community/activation",
+                layer_name="RmsNorm",
+                revision="layers",
+            ),
+        },
+    }
+    register_kernel_mapping(kernel_layer_mapping)
+    ```
+    """
+    # Merge with existing mappings.
+    for new_kernel, new_device_repos in mapping.items():
+        device_repo = _KERNEL_MAPPING.get().setdefault(new_kernel, {})
+        for new_device, new_repo in new_device_repos.items():
+            if isinstance(new_device, str):
+                device_repo[Device(type=new_device)] = new_repo
+            else:
+                device_repo[new_device] = new_repo
+
+
+def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool = True):
+    """
+    Replace the forward function of a layer using a layer from the kernel hub.
+    This function monkeypatches a layer, replacing the `forward` method
+    of the layer with that of a layer from the hub. The replacement is done
+    when a layer matching `layer_name` and device type is registered through
+    `register_kernel_mapping`. The device type is inferred from the first
+    argument to `forward`.
+    """
+
+    fallback_forward = cls.forward
+
+    cached_forward: Dict[LayerRepository, Callable] = {}
+
+    def forward(self, x, **args):
+        kernel = _KERNEL_MAPPING.get().get(layer_name)
+        if kernel is None:
+            if not use_fallback:
+                raise ValueError(f"No layer mapping for `{layer_name}`")
+            return fallback_forward(self, x, **args)
+
+        device = getattr(x, "device", None)
+        if device is None:
+            return fallback_forward(self, x, **args)
+
+        repo = kernel.get(Device(type=device.type))
+        if repo is None:
+            if not use_fallback:
+                raise ValueError(
+                    f"No layer mapping for `{layer_name}` with device type `{device.type}`"
+                )
+            return fallback_forward(self, x, **args)
+
+        # Short-circuit if we already loaded the layer.
+        layer_forward = cached_forward.get(repo, None)
+        if layer_forward is not None:
+            return layer_forward(self, x, **args)
+
+        layer = _get_kernel_layer(
+            repo_id=repo.repo_id,
+            layer_name=repo.layer_name,
+            revision=repo.revision,
+        )
+
+        # We have to validate against the original signature.
+        orig_forward = cls.forward
+        try:
+            cls.forward = fallback_forward
+            _validate_layer(check_cls=cls, cls=layer)
+        finally:
+            cls.forward = orig_forward
+
+        layer_forward = layer.forward
+        cached_forward[repo] = layer_forward
+
+        return layer_forward(self, x, **args)
+
+    cls.forward = forward
+
+
+def use_kernel_forward_from_hub(layer_name: str, *, use_fallback: bool = True):
+    """
+    Replace the forward function of a layer using a layer from the kernel hub.
+    This decorator can be applied to a layer and replaces the forward method
+    of the layer with that of a layer from the hub. The replacement is done
+    when a layer matching `layer_name` and device type is registered through
+    `register_kernel_mapping`. The device type is inferred from the first
+    argument to `forward`.
+    """
+
+    def decorator(cls):
+        replace_kernel_forward_from_hub(cls, layer_name, use_fallback=use_fallback)
+        return cls
+
+    return decorator
+
+
+def _get_kernel_layer(*, repo_id: str, layer_name: str, revision: str) -> "nn.Module":
+    """Get a layer from a kernel."""
+
+    kernel = get_kernel(repo_id, revision=revision)
+
+    if getattr(kernel, "layers", None) is None:
+        raise ValueError(
+            f"Kernel `{repo_id}` at revision `{revision}` does not define any layers."
+        )
+
+    layer = getattr(kernel.layers, layer_name, None)
+    if layer is None:
+        raise ValueError(f"Layer `{layer_name}` not found in kernel `{repo_id}`.")
+    return layer
+
+
+def _validate_layer(*, check_cls, cls):
+    # The layer must at least have the following properties: (1) it
+    # must be stateless; (2) the forward signature should correspond to
+    # the signature it is replacing; (3) forward should not call other
+    # methods.
+
+    from torch import nn
+
+    if not issubclass(cls, nn.Module):
+        raise TypeError(f"Layer `{cls}` is not a Torch layer.")
+
+    # We verify statelessness by checking that the class does not have its own
+    # constructor (since the constructor could add member variables)...
+    if cls.__init__ is not nn.Module.__init__:
+        raise TypeError("Layer must not override nn.Module constructor.")
+
+    # ... or predefined member variables.
+    torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
+    cls_members = {name for name, _ in inspect.getmembers(cls)}
+    if cls_members - torch_module_members != set():
+        raise TypeError("Layer must not contain additional members.")
+
+    # Check whether the forward signatures are similar.
+    params = inspect.signature(cls.forward).parameters
+    ref_params = inspect.signature(check_cls.forward).parameters
+
+    if len(params) != len(ref_params):
+        raise TypeError(
+            "Forward signature does not match: different number of arguments."
+        )
+
+    for param, ref_param in zip(params.values(), ref_params.values()):
+        if param.kind != ref_param.kind:
+            raise TypeError(
+                f"Forward signature does not match: different kind of arguments ({param} ({param.kind}) and {ref_param} ({ref_param.kind}))"
+            )
```
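To make the new machinery concrete, here is a hedged usage sketch combining the decorator, `register_kernel_mapping`, and the `use_kernel_mapping` context manager. The repository, revision, and layer names are taken from this release's test suite; the sketch assumes a CUDA device is available and that the hub layer passes the validation above:

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

from kernels import (
    Device,
    LayerRepository,
    register_kernel_mapping,
    use_kernel_forward_from_hub,
)
from kernels.layer import use_kernel_mapping


@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMul(nn.Module):
    # Pure-PyTorch fallback, used when no mapping matches the input's device.
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        d = input.shape[-1] // 2
        return F.silu(input[..., :d]) * input[..., d:]


# Process-wide registration: route `SiluAndMul` to a hub kernel on CUDA.
register_kernel_mapping(
    {
        "SiluAndMul": {
            Device(type="cuda"): LayerRepository(
                repo_id="kernels-community/activation",
                layer_name="SiluAndMul",
                revision="layers",
            )
        }
    }
)

layer = SiluAndMul()
x = torch.randn(32, 64, device="cuda")
y = layer(x)  # Dispatches to the hub layer's forward on CUDA inputs.

# Scoped registration: the mapping stacks on top of the current one and is
# restored when the context exits (see the ContextVar handling above).
with use_kernel_mapping(
    {
        "SiluAndMul": {
            "cuda": LayerRepository(
                repo_id="kernels-community/activation",
                layer_name="SiluAndMul",
                revision="layers",
            )
        }
    }
):
    y = layer(x)
```

Note that dispatch happens per call, not at decoration time: on a CPU tensor, or when no mapping matches, the original Python `forward` runs instead.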
**{kernels-0.2.0 → kernels-0.3.0}/src/kernels/utils.py**

```diff
@@ -144,9 +144,18 @@ def get_kernel(repo_id: str, revision: str = "main") -> ModuleType:
     return import_from_path(package_name, package_path / package_name / "__init__.py")
 
 
-def load_kernel(repo_id: str) -> ModuleType:
-    """
-
+def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
+    """
+    Get a pre-downloaded, locked kernel.
+
+    If `lockfile` is not specified, the lockfile will be loaded from the
+    caller's package metadata.
+    """
+    if lockfile is None:
+        locked_sha = _get_caller_locked_kernel(repo_id)
+    else:
+        with open(lockfile, "r") as f:
+            locked_sha = _get_locked_kernel(repo_id, f.read())
 
     if locked_sha is None:
         raise ValueError(
@@ -163,6 +172,7 @@ def load_kernel(repo_id: str) -> ModuleType:
         repo_id,
         allow_patterns=[f"build/{variant}/*", f"build/{universal_variant}/*"],
         cache_dir=CACHE_DIR,
+        revision=locked_sha,
         local_files_only=True,
     )
 )
@@ -200,11 +210,19 @@ def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleType:
 def _get_caller_locked_kernel(repo_id: str) -> Optional[str]:
     for dist in _get_caller_distributions():
         lock_json = dist.read_text("kernels.lock")
-        if lock_json is not None:
-            for kernel_lock_json in json.loads(lock_json):
-                kernel_lock = KernelLock.from_json(kernel_lock_json)
-                if kernel_lock.repo_id == repo_id:
-                    return kernel_lock.sha
+        if lock_json is None:
+            continue
+        locked_sha = _get_locked_kernel(repo_id, lock_json)
+        if locked_sha is not None:
+            return locked_sha
+    return None
+
+
+def _get_locked_kernel(repo_id: str, lock_json: str) -> Optional[str]:
+    for kernel_lock_json in json.loads(lock_json):
+        kernel_lock = KernelLock.from_json(kernel_lock_json)
+        if kernel_lock.repo_id == repo_id:
+            return kernel_lock.sha
     return None
 
 
```
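For reference, a hedged sketch of the two call styles the extended `load_kernel` signature allows. The explicit-lockfile form mirrors `tests/test_kernel_locking.py`; the lockfile path here is hypothetical, and in both cases the kernel must already be present in the local cache, since the snapshot is resolved with `local_files_only=True`:

```python
from pathlib import Path

from kernels import load_kernel

# 1. Implicit: resolve the locked revision from the calling package's
#    metadata (a `kernels.lock` file shipped with the distribution).
activation = load_kernel("kernels-community/activation")

# 2. Explicit: read the locked revision from a lockfile on disk.
activation = load_kernel(
    "kernels-community/activation",
    lockfile=Path("kernels.lock"),  # hypothetical path
)
```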
|
**{kernels-0.2.0 → kernels-0.3.0}/src/kernels.egg-info/PKG-INFO**

```diff
@@ -1,7 +1,7 @@
-Metadata-Version: 2.
+Metadata-Version: 2.1
 Name: kernels
-Version: 0.2.0
-Summary: Download
+Version: 0.3.0
+Summary: Download compute kernels
 Author-email: OlivierDehaene <olivier@huggingface.co>, Daniel de Kok <daniel@huggingface.co>, David Holtz <david@huggingface.co>, Nicolas Patry <nicolas@huggingface.co>
 Requires-Python: >=3.9
 Description-Content-Type: text/markdown
@@ -57,6 +57,7 @@ the Hub.
 
 ## 📚 Documentation
 
+- [Using layers](docs/layers.md)
 - [Locking kernel versions](docs/locking.md)
 - [Using kernels in a Docker container](docs/docker.md)
 - [Kernel requirements](docs/kernel-requirements.md)
```
**{kernels-0.2.0 → kernels-0.3.0}/src/kernels.egg-info/SOURCES.txt**

```diff
@@ -3,6 +3,7 @@ pyproject.toml
 src/kernels/__init__.py
 src/kernels/cli.py
 src/kernels/compat.py
+src/kernels/layer.py
 src/kernels/lockfile.py
 src/kernels/utils.py
 src/kernels.egg-info/PKG-INFO
@@ -13,4 +14,5 @@ src/kernels.egg-info/requires.txt
 src/kernels.egg-info/top_level.txt
 tests/test_basic.py
 tests/test_benchmarks.py
-tests/test_hash_validation.py
+tests/test_kernel_locking.py
+tests/test_layer.py
```
**kernels-0.3.0/src/kernels.egg-info/top_level.txt (new)**

```diff
@@ -0,0 +1 @@
+kernels
```
**kernels-0.2.0/tests/test_hash_validation.py → kernels-0.3.0/tests/test_kernel_locking.py**

```diff
@@ -1,6 +1,7 @@
 from dataclasses import dataclass
 from pathlib import Path
 
+from kernels import load_kernel
 from kernels.cli import download_kernels
 
 
@@ -11,11 +12,13 @@ class DownloadArgs:
     project_dir: Path
 
 
-def test_download_hash_validation():
-    project_dir = Path(__file__).parent / "hash_validation"
-    download_kernels(DownloadArgs(all_variants=False, project_dir=project_dir))
-
-
 def test_download_all_hash_validation():
-    project_dir = Path(__file__).parent / "hash_validation"
+    project_dir = Path(__file__).parent / "kernel_locking"
     download_kernels(DownloadArgs(all_variants=True, project_dir=project_dir))
+
+
+def test_load_locked():
+    project_dir = Path(__file__).parent / "kernel_locking"
+    # Also validates that hashing works correctly.
+    download_kernels(DownloadArgs(all_variants=False, project_dir=project_dir))
+    load_kernel("kernels-community/activation", lockfile=project_dir / "kernels.lock")
```
**kernels-0.3.0/tests/test_layer.py (new)**

```diff
@@ -0,0 +1,168 @@
+import pytest
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+from kernels import (
+    Device,
+    LayerRepository,
+    register_kernel_mapping,
+    use_kernel_forward_from_hub,
+)
+from kernels.layer import _KERNEL_MAPPING, _validate_layer, use_kernel_mapping
+
+kernel_layer_mapping = {
+    "SiluAndMul": {
+        Device(type="cuda"): LayerRepository(
+            repo_id="kernels-community/activation",
+            layer_name="SiluAndMul",
+            revision="layers",
+        )
+    },
+    "SiluAndMulStringDevice": {
+        "cuda": LayerRepository(
+            repo_id="kernels-community/activation",
+            layer_name="SiluAndMul",
+            revision="layers",
+        )
+    },
+}
+
+register_kernel_mapping(kernel_layer_mapping)
+
+
+class SiluAndMul(nn.Module):
+    def __init__(self):
+        super().__init__()
+        # Used to check that we called the hub kernel.
+        self.n_calls = 0
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        self.n_calls += 1
+        d = input.shape[-1] // 2
+        return F.silu(input[..., :d]) * input[..., d:]
+
+
+@use_kernel_forward_from_hub("SiluAndMul")
+class SiluAndMulWithKernel(SiluAndMul):
+    pass
+
+
+@use_kernel_forward_from_hub("SiluAndMulStringDevice")
+class SiluAndMulStringDevice(SiluAndMul):
+    pass
+
+
+@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice])
+@pytest.mark.parametrize("device", ["cuda", "cpu"])
+def test_hub_forward(cls, device):
+    torch.random.manual_seed(0)
+
+    silu_and_mul = SiluAndMul()
+    X = torch.randn((32, 64), device=device)
+    Y = silu_and_mul(X)
+
+    silu_and_mul_with_kernel = cls()
+    Y_kernel = silu_and_mul_with_kernel(X)
+
+    torch.testing.assert_close(Y_kernel, Y)
+
+    assert silu_and_mul.n_calls == 1
+    if device == "cuda":
+        assert silu_and_mul_with_kernel.n_calls == 0
+    else:
+        assert silu_and_mul_with_kernel.n_calls == 1
+
+
+def test_layer_fallback_works():
+    @use_kernel_forward_from_hub("SiluAndMulNonExisting")
+    class SiluAndMulWithKernelFallback(SiluAndMul):
+        pass
+
+    # Check that we don't raise an exception for a non-existing kernel.
+    SiluAndMulWithKernelFallback()
+
+
+def test_mapping_contexts():
+    assert set(_KERNEL_MAPPING.get().keys()) == {"SiluAndMul", "SiluAndMulStringDevice"}
+
+    extra_mapping1 = {
+        "TestKernel": {
+            Device(type="cuda"): LayerRepository(
+                repo_id="kernels-community/activation",
+                layer_name="SiluAndMul",
+                revision="layers",
+            )
+        }
+    }
+
+    with use_kernel_mapping(extra_mapping1):
+        assert set(_KERNEL_MAPPING.get().keys()) == {
+            "SiluAndMul",
+            "SiluAndMulStringDevice",
+            "TestKernel",
+        }
+
+        extra_mapping2 = {
+            "SiluAndMul": {
+                Device(type="cuda"): LayerRepository(
+                    repo_id="kernels-community/non-existing",
+                    layer_name="SiluAndMul",
+                    revision="layers",
+                )
+            }
+        }
+
+        with use_kernel_mapping(extra_mapping2):
+            assert set(_KERNEL_MAPPING.get().keys()) == {
+                "SiluAndMul",
+                "SiluAndMulStringDevice",
+                "TestKernel",
+            }
+            assert (
+                _KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
+                == "kernels-community/non-existing"
+            )
+
+        assert set(_KERNEL_MAPPING.get().keys()) == {
+            "SiluAndMul",
+            "SiluAndMulStringDevice",
+            "TestKernel",
+        }
+        assert (
+            _KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
+            == "kernels-community/activation"
+        )
+
+    assert set(_KERNEL_MAPPING.get().keys()) == {
+        "SiluAndMul",
+        "SiluAndMulStringDevice",
+    }
+
+
+def test_validate_kernel_layer():
+    class BadLayer(nn.Module):
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            self.foo = 42
+
+    with pytest.raises(TypeError, match="not override"):
+        _validate_layer(cls=BadLayer, check_cls=SiluAndMul)
+
+    class BadLayer2(nn.Module):
+        foo: int = 42
+
+    with pytest.raises(TypeError, match="not contain additional members"):
+        _validate_layer(cls=BadLayer2, check_cls=SiluAndMul)
+
+    class BadLayer3(nn.Module):
+        def forward(self, x: torch.Tensor, foo: int) -> torch.Tensor: ...
+
+    with pytest.raises(TypeError, match="different number of arguments"):
+        _validate_layer(cls=BadLayer3, check_cls=SiluAndMul)
+
+    class BadLayer4(nn.Module):
+        def forward(self, *, x: torch.Tensor) -> torch.Tensor: ...
+
+    with pytest.raises(TypeError, match="different kind of arguments"):
+        _validate_layer(cls=BadLayer4, check_cls=SiluAndMul)
```
The nine remaining files in the list above (+0 -0) are unchanged between the two versions.