kernels 0.4.3__tar.gz → 0.5.0.dev0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {kernels-0.4.3/src/kernels.egg-info → kernels-0.5.0.dev0}/PKG-INFO +14 -2
- {kernels-0.4.3 → kernels-0.5.0.dev0}/README.md +11 -0
- {kernels-0.4.3 → kernels-0.5.0.dev0}/pyproject.toml +1 -1
- {kernels-0.4.3 → kernels-0.5.0.dev0}/src/kernels/__init__.py +2 -0
- {kernels-0.4.3 → kernels-0.5.0.dev0}/src/kernels/layer.py +47 -9
- {kernels-0.4.3 → kernels-0.5.0.dev0}/src/kernels/utils.py +24 -1
- {kernels-0.4.3 → kernels-0.5.0.dev0/src/kernels.egg-info}/PKG-INFO +14 -2
- {kernels-0.4.3 → kernels-0.5.0.dev0}/tests/test_basic.py +17 -1
- {kernels-0.4.3 → kernels-0.5.0.dev0}/tests/test_layer.py +110 -1
- {kernels-0.4.3 → kernels-0.5.0.dev0}/LICENSE +0 -0
- {kernels-0.4.3 → kernels-0.5.0.dev0}/setup.cfg +0 -0
- {kernels-0.4.3 → kernels-0.5.0.dev0}/src/kernels/cli.py +0 -0
- {kernels-0.4.3 → kernels-0.5.0.dev0}/src/kernels/compat.py +0 -0
- {kernels-0.4.3 → kernels-0.5.0.dev0}/src/kernels/lockfile.py +0 -0
- {kernels-0.4.3 → kernels-0.5.0.dev0}/src/kernels.egg-info/SOURCES.txt +0 -0
- {kernels-0.4.3 → kernels-0.5.0.dev0}/src/kernels.egg-info/dependency_links.txt +0 -0
- {kernels-0.4.3 → kernels-0.5.0.dev0}/src/kernels.egg-info/entry_points.txt +0 -0
- {kernels-0.4.3 → kernels-0.5.0.dev0}/src/kernels.egg-info/requires.txt +0 -0
- {kernels-0.4.3 → kernels-0.5.0.dev0}/src/kernels.egg-info/top_level.txt +0 -0
- {kernels-0.4.3 → kernels-0.5.0.dev0}/tests/test_benchmarks.py +0 -0
- {kernels-0.4.3 → kernels-0.5.0.dev0}/tests/test_kernel_locking.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: kernels
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.5.0.dev0
|
|
4
4
|
Summary: Download compute kernels
|
|
5
5
|
Author-email: OlivierDehaene <olivier@huggingface.co>, Daniel de Kok <daniel@huggingface.co>, David Holtz <david@huggingface.co>, Nicolas Patry <nicolas@huggingface.co>
|
|
6
6
|
License: Apache-2.0
|
|
@@ -12,9 +12,21 @@ Requires-Dist: packaging>=20.0
|
|
|
12
12
|
Requires-Dist: tomli>=2.0; python_version < "3.11"
|
|
13
13
|
Provides-Extra: torch
|
|
14
14
|
Requires-Dist: torch; extra == "torch"
|
|
15
|
+
Dynamic: license-file
|
|
15
16
|
|
|
16
17
|
# kernels
|
|
17
18
|
|
|
19
|
+
<div align="center">
|
|
20
|
+
<img src="https://github.com/user-attachments/assets/64a652f3-0cd3-4829-b3c1-df13f7933569" width="450" height="450" alt="kernel-builder logo">
|
|
21
|
+
<p align="center">
|
|
22
|
+
<a href="https://pypi.org/project/kernels"><img alt="PyPI - Version" src="https://img.shields.io/pypi/v/kernels"></a>
|
|
23
|
+
<a href="https://github.com/huggingface/kernels/tags"><img alt="GitHub tag" src="https://img.shields.io/github/v/tag/huggingface/kernels"></a>
|
|
24
|
+
<a href="https://github.com/huggingface/kernels/actions/workflows/docker-build-push.yaml"><img alt="Test kernels" src="https://img.shields.io/github/actions/workflow/status/huggingface/kernels/test.yml?label=test"></a>
|
|
25
|
+
|
|
26
|
+
</p>
|
|
27
|
+
</div>
|
|
28
|
+
<hr/>
|
|
29
|
+
|
|
18
30
|
The Kernel Hub allows Python libraries and applications to load compute
|
|
19
31
|
kernels directly from the [Hub](https://hf.co/). To support this kind
|
|
20
32
|
of dynamic loading, Hub kernels differ from traditional Python kernel
|
|
@@ -1,5 +1,16 @@
|
|
|
1
1
|
# kernels
|
|
2
2
|
|
|
3
|
+
<div align="center">
|
|
4
|
+
<img src="https://github.com/user-attachments/assets/64a652f3-0cd3-4829-b3c1-df13f7933569" width="450" height="450" alt="kernel-builder logo">
|
|
5
|
+
<p align="center">
|
|
6
|
+
<a href="https://pypi.org/project/kernels"><img alt="PyPI - Version" src="https://img.shields.io/pypi/v/kernels"></a>
|
|
7
|
+
<a href="https://github.com/huggingface/kernels/tags"><img alt="GitHub tag" src="https://img.shields.io/github/v/tag/huggingface/kernels"></a>
|
|
8
|
+
<a href="https://github.com/huggingface/kernels/actions/workflows/docker-build-push.yaml"><img alt="Test kernels" src="https://img.shields.io/github/actions/workflow/status/huggingface/kernels/test.yml?label=test"></a>
|
|
9
|
+
|
|
10
|
+
</p>
|
|
11
|
+
</div>
|
|
12
|
+
<hr/>
|
|
13
|
+
|
|
3
14
|
The Kernel Hub allows Python libraries and applications to load compute
|
|
4
15
|
kernels directly from the [Hub](https://hf.co/). To support this kind
|
|
5
16
|
of dynamic loading, Hub kernels differ from traditional Python kernel
|
|
@@ -9,6 +9,7 @@ from kernels.layer import (
|
|
|
9
9
|
from kernels.utils import (
|
|
10
10
|
get_kernel,
|
|
11
11
|
get_locked_kernel,
|
|
12
|
+
has_kernel,
|
|
12
13
|
install_kernel,
|
|
13
14
|
load_kernel,
|
|
14
15
|
)
|
|
@@ -16,6 +17,7 @@ from kernels.utils import (
|
|
|
16
17
|
__all__ = [
|
|
17
18
|
"get_kernel",
|
|
18
19
|
"get_locked_kernel",
|
|
20
|
+
"has_kernel",
|
|
19
21
|
"load_kernel",
|
|
20
22
|
"install_kernel",
|
|
21
23
|
"use_kernel_forward_from_hub",
|
|
@@ -4,7 +4,7 @@ import warnings
|
|
|
4
4
|
from contextvars import ContextVar
|
|
5
5
|
from copy import deepcopy
|
|
6
6
|
from dataclasses import dataclass, field
|
|
7
|
-
from typing import TYPE_CHECKING,
|
|
7
|
+
from typing import TYPE_CHECKING, Dict, Union
|
|
8
8
|
|
|
9
9
|
from .utils import get_kernel
|
|
10
10
|
|
|
@@ -131,12 +131,15 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
|
|
|
131
131
|
|
|
132
132
|
fallback_forward = cls.forward
|
|
133
133
|
|
|
134
|
-
|
|
134
|
+
cached_layer: Dict[LayerRepository, nn.Module] = {}
|
|
135
135
|
|
|
136
136
|
def forward(self, x, *args, **kwargs):
|
|
137
137
|
if _DISABLE_KERNEL_MAPPING:
|
|
138
138
|
return fallback_forward(self, x, *args, **kwargs)
|
|
139
139
|
|
|
140
|
+
needs_backward = self.training
|
|
141
|
+
is_compiling = _is_torchdynamo_compiling()
|
|
142
|
+
|
|
140
143
|
kernel = _KERNEL_MAPPING.get().get(layer_name)
|
|
141
144
|
if kernel is None:
|
|
142
145
|
warnings.warn(
|
|
@@ -162,9 +165,18 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
|
|
|
162
165
|
return fallback_forward(self, x, *args, **kwargs)
|
|
163
166
|
|
|
164
167
|
# Short-circuit if we already loaded the layer.
|
|
165
|
-
|
|
166
|
-
if
|
|
167
|
-
|
|
168
|
+
layer = cached_layer.get(repo, None)
|
|
169
|
+
if layer is not None:
|
|
170
|
+
# Switch to fallback when the layer does not support:
|
|
171
|
+
# compilation/compile when needed.
|
|
172
|
+
# backward when needed
|
|
173
|
+
needs_fallback = needs_backward and not getattr(layer, "has_backward", True)
|
|
174
|
+
needs_fallback |= is_compiling and not getattr(
|
|
175
|
+
layer, "can_torch_compile", False
|
|
176
|
+
)
|
|
177
|
+
if needs_fallback:
|
|
178
|
+
return fallback_forward(self, x, *args, **kwargs)
|
|
179
|
+
return layer.forward(self, x, *args, **kwargs)
|
|
168
180
|
|
|
169
181
|
layer = _get_kernel_layer(
|
|
170
182
|
repo_id=repo.repo_id,
|
|
@@ -180,10 +192,18 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
|
|
|
180
192
|
finally:
|
|
181
193
|
cls.forward = orig_forward
|
|
182
194
|
|
|
183
|
-
|
|
184
|
-
|
|
195
|
+
cached_layer[repo] = layer
|
|
196
|
+
|
|
197
|
+
# Switch to fallback when the layer does not support
|
|
198
|
+
# compilation/compile when needed.
|
|
199
|
+
needs_fallback = needs_backward and not getattr(layer, "has_backward", True)
|
|
200
|
+
needs_fallback |= is_compiling and not getattr(
|
|
201
|
+
layer, "can_torch_compile", False
|
|
202
|
+
)
|
|
203
|
+
if needs_fallback:
|
|
204
|
+
return fallback_forward(self, x, *args, **kwargs)
|
|
185
205
|
|
|
186
|
-
return
|
|
206
|
+
return layer.forward(self, x, *args, **kwargs)
|
|
187
207
|
|
|
188
208
|
cls.forward = forward
|
|
189
209
|
|
|
@@ -240,7 +260,9 @@ def _validate_layer(*, check_cls, cls):
|
|
|
240
260
|
# ... or predefined member variables.
|
|
241
261
|
torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
|
|
242
262
|
cls_members = {name for name, _ in inspect.getmembers(cls)}
|
|
243
|
-
|
|
263
|
+
difference = cls_members - torch_module_members
|
|
264
|
+
# verify if : difference ⊄ {"can_torch_compile", "has_backward"}
|
|
265
|
+
if not difference <= {"can_torch_compile", "has_backward"}:
|
|
244
266
|
raise TypeError("Layer must not contain additional members.")
|
|
245
267
|
|
|
246
268
|
# Check whether the forward signatures are similar.
|
|
@@ -257,3 +279,19 @@ def _validate_layer(*, check_cls, cls):
|
|
|
257
279
|
raise TypeError(
|
|
258
280
|
f"Forward signature does not match: different kind of arguments ({param} ({param.kind}) and {ref_param} ({ref_param.kind})"
|
|
259
281
|
)
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def _is_torchdynamo_compiling():
|
|
285
|
+
# Importing torch._dynamo causes issues with PyTorch profiler (https://github.com/pytorch/pytorch/issues/130622)
|
|
286
|
+
# hence rather relying on `torch.compiler.is_compiling()` when possible (torch>=2.3)
|
|
287
|
+
try:
|
|
288
|
+
import torch
|
|
289
|
+
|
|
290
|
+
return torch.compiler.is_compiling()
|
|
291
|
+
except Exception:
|
|
292
|
+
try:
|
|
293
|
+
import torch._dynamo as dynamo # noqa: F401
|
|
294
|
+
|
|
295
|
+
return dynamo.is_compiling()
|
|
296
|
+
except Exception:
|
|
297
|
+
return False
|
|
@@ -13,7 +13,7 @@ from pathlib import Path
|
|
|
13
13
|
from types import ModuleType
|
|
14
14
|
from typing import Dict, List, Optional, Tuple
|
|
15
15
|
|
|
16
|
-
from huggingface_hub import snapshot_download
|
|
16
|
+
from huggingface_hub import file_exists, snapshot_download
|
|
17
17
|
from packaging.version import parse
|
|
18
18
|
|
|
19
19
|
from kernels.lockfile import KernelLock, VariantLock
|
|
@@ -161,6 +161,29 @@ def get_kernel(repo_id: str, revision: str = "main") -> ModuleType:
|
|
|
161
161
|
return import_from_path(package_name, package_path / package_name / "__init__.py")
|
|
162
162
|
|
|
163
163
|
|
|
164
|
+
def has_kernel(repo_id: str, revision: str = "main") -> bool:
|
|
165
|
+
"""
|
|
166
|
+
Check whether a kernel build exists for the current environment
|
|
167
|
+
(Torch version and compute framework).
|
|
168
|
+
"""
|
|
169
|
+
package_name = package_name_from_repo_id(repo_id)
|
|
170
|
+
variant = build_variant()
|
|
171
|
+
universal_variant = universal_build_variant()
|
|
172
|
+
|
|
173
|
+
if file_exists(
|
|
174
|
+
repo_id,
|
|
175
|
+
revision=revision,
|
|
176
|
+
filename=f"build/{universal_variant}/{package_name}/__init__.py",
|
|
177
|
+
):
|
|
178
|
+
return True
|
|
179
|
+
|
|
180
|
+
return file_exists(
|
|
181
|
+
repo_id,
|
|
182
|
+
revision=revision,
|
|
183
|
+
filename=f"build/{variant}/{package_name}/__init__.py",
|
|
184
|
+
)
|
|
185
|
+
|
|
186
|
+
|
|
164
187
|
def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
|
|
165
188
|
"""
|
|
166
189
|
Get a pre-downloaded, locked kernel.
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: kernels
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.5.0.dev0
|
|
4
4
|
Summary: Download compute kernels
|
|
5
5
|
Author-email: OlivierDehaene <olivier@huggingface.co>, Daniel de Kok <daniel@huggingface.co>, David Holtz <david@huggingface.co>, Nicolas Patry <nicolas@huggingface.co>
|
|
6
6
|
License: Apache-2.0
|
|
@@ -12,9 +12,21 @@ Requires-Dist: packaging>=20.0
|
|
|
12
12
|
Requires-Dist: tomli>=2.0; python_version < "3.11"
|
|
13
13
|
Provides-Extra: torch
|
|
14
14
|
Requires-Dist: torch; extra == "torch"
|
|
15
|
+
Dynamic: license-file
|
|
15
16
|
|
|
16
17
|
# kernels
|
|
17
18
|
|
|
19
|
+
<div align="center">
|
|
20
|
+
<img src="https://github.com/user-attachments/assets/64a652f3-0cd3-4829-b3c1-df13f7933569" width="450" height="450" alt="kernel-builder logo">
|
|
21
|
+
<p align="center">
|
|
22
|
+
<a href="https://pypi.org/project/kernels"><img alt="PyPI - Version" src="https://img.shields.io/pypi/v/kernels"></a>
|
|
23
|
+
<a href="https://github.com/huggingface/kernels/tags"><img alt="GitHub tag" src="https://img.shields.io/github/v/tag/huggingface/kernels"></a>
|
|
24
|
+
<a href="https://github.com/huggingface/kernels/actions/workflows/docker-build-push.yaml"><img alt="Test kernels" src="https://img.shields.io/github/actions/workflow/status/huggingface/kernels/test.yml?label=test"></a>
|
|
25
|
+
|
|
26
|
+
</p>
|
|
27
|
+
</div>
|
|
28
|
+
<hr/>
|
|
29
|
+
|
|
18
30
|
The Kernel Hub allows Python libraries and applications to load compute
|
|
19
31
|
kernels directly from the [Hub](https://hf.co/). To support this kind
|
|
20
32
|
of dynamic loading, Hub kernels differ from traditional Python kernel
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import pytest
|
|
2
2
|
import torch
|
|
3
3
|
|
|
4
|
-
from kernels import get_kernel
|
|
4
|
+
from kernels import get_kernel, has_kernel
|
|
5
5
|
|
|
6
6
|
|
|
7
7
|
@pytest.fixture
|
|
@@ -36,6 +36,22 @@ def test_gelu_fast(kernel, device):
|
|
|
36
36
|
assert torch.allclose(y, expected)
|
|
37
37
|
|
|
38
38
|
|
|
39
|
+
@pytest.mark.parametrize(
|
|
40
|
+
"kernel_exists",
|
|
41
|
+
[
|
|
42
|
+
("kernels-community/activation", "main", True),
|
|
43
|
+
("kernels-community/triton-layer-norm", "main", True),
|
|
44
|
+
# Repo only contains Torch 2.4 kernels (and we don't
|
|
45
|
+
# support/test against this version).
|
|
46
|
+
("kernels-test/only-torch-2.4", "main", False),
|
|
47
|
+
("google-bert/bert-base-uncased", "87565a309", False),
|
|
48
|
+
],
|
|
49
|
+
)
|
|
50
|
+
def test_has_kernel(kernel_exists):
|
|
51
|
+
repo_id, revision, kernel = kernel_exists
|
|
52
|
+
assert has_kernel(repo_id, revision=revision) == kernel
|
|
53
|
+
|
|
54
|
+
|
|
39
55
|
def test_universal_kernel(universal_kernel):
|
|
40
56
|
torch.manual_seed(0)
|
|
41
57
|
A = torch.randint(-10, 10, (64, 128), dtype=torch.int8, device="cuda")
|
|
@@ -19,6 +19,12 @@ kernel_layer_mapping = {
|
|
|
19
19
|
revision="layers",
|
|
20
20
|
)
|
|
21
21
|
},
|
|
22
|
+
"SiluAndMulNoCompile": {
|
|
23
|
+
"cuda": LayerRepository(
|
|
24
|
+
repo_id="kernels-test/op-without-fake-test",
|
|
25
|
+
layer_name="SiluAndMul",
|
|
26
|
+
)
|
|
27
|
+
},
|
|
22
28
|
"SiluAndMulStringDevice": {
|
|
23
29
|
"cuda": LayerRepository(
|
|
24
30
|
repo_id="kernels-community/activation",
|
|
@@ -43,6 +49,11 @@ class SiluAndMul(nn.Module):
|
|
|
43
49
|
return F.silu(input[..., :d]) * input[..., d:]
|
|
44
50
|
|
|
45
51
|
|
|
52
|
+
@use_kernel_forward_from_hub("SiluAndMulNoCompile")
|
|
53
|
+
class SiluAndMulNoCompileKernel(SiluAndMul):
|
|
54
|
+
pass
|
|
55
|
+
|
|
56
|
+
|
|
46
57
|
@use_kernel_forward_from_hub("SiluAndMul")
|
|
47
58
|
class SiluAndMulWithKernel(SiluAndMul):
|
|
48
59
|
pass
|
|
@@ -101,8 +112,29 @@ def test_layer_fallback_works():
|
|
|
101
112
|
SiluAndMulWithKernelFallback()
|
|
102
113
|
|
|
103
114
|
|
|
115
|
+
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
|
|
116
|
+
@pytest.mark.parametrize("device", ["cuda", "cpu"])
|
|
117
|
+
def test_torch_compile_layer(cls, device):
|
|
118
|
+
silu_and_mul = SiluAndMul()
|
|
119
|
+
|
|
120
|
+
X = torch.randn((32, 64), dtype=torch.float32, device=device)
|
|
121
|
+
Y = silu_and_mul(X)
|
|
122
|
+
|
|
123
|
+
silu_and_mul_with_kernel = cls()
|
|
124
|
+
silu_and_mul_with_kernel.eval()
|
|
125
|
+
silu_and_mul_compiled = torch.compile(silu_and_mul_with_kernel)
|
|
126
|
+
|
|
127
|
+
Y_compiled = silu_and_mul_compiled(X)
|
|
128
|
+
|
|
129
|
+
torch.testing.assert_close(Y_compiled, Y)
|
|
130
|
+
|
|
131
|
+
|
|
104
132
|
def test_mapping_contexts():
|
|
105
|
-
assert set(_KERNEL_MAPPING.get().keys()) == {
|
|
133
|
+
assert set(_KERNEL_MAPPING.get().keys()) == {
|
|
134
|
+
"SiluAndMul",
|
|
135
|
+
"SiluAndMulStringDevice",
|
|
136
|
+
"SiluAndMulNoCompile",
|
|
137
|
+
}
|
|
106
138
|
|
|
107
139
|
extra_mapping1 = {
|
|
108
140
|
"TestKernel": {
|
|
@@ -118,6 +150,7 @@ def test_mapping_contexts():
|
|
|
118
150
|
assert set(_KERNEL_MAPPING.get().keys()) == {
|
|
119
151
|
"SiluAndMul",
|
|
120
152
|
"SiluAndMulStringDevice",
|
|
153
|
+
"SiluAndMulNoCompile",
|
|
121
154
|
"TestKernel",
|
|
122
155
|
}
|
|
123
156
|
|
|
@@ -135,6 +168,7 @@ def test_mapping_contexts():
|
|
|
135
168
|
assert set(_KERNEL_MAPPING.get().keys()) == {
|
|
136
169
|
"SiluAndMul",
|
|
137
170
|
"SiluAndMulStringDevice",
|
|
171
|
+
"SiluAndMulNoCompile",
|
|
138
172
|
"TestKernel",
|
|
139
173
|
}
|
|
140
174
|
assert (
|
|
@@ -145,6 +179,7 @@ def test_mapping_contexts():
|
|
|
145
179
|
assert set(_KERNEL_MAPPING.get().keys()) == {
|
|
146
180
|
"SiluAndMul",
|
|
147
181
|
"SiluAndMulStringDevice",
|
|
182
|
+
"SiluAndMulNoCompile",
|
|
148
183
|
"TestKernel",
|
|
149
184
|
}
|
|
150
185
|
assert (
|
|
@@ -164,6 +199,7 @@ def test_mapping_contexts():
|
|
|
164
199
|
assert set(_KERNEL_MAPPING.get().keys()) == {
|
|
165
200
|
"SiluAndMul",
|
|
166
201
|
"SiluAndMulStringDevice",
|
|
202
|
+
"SiluAndMulNoCompile",
|
|
167
203
|
"TestKernel",
|
|
168
204
|
}
|
|
169
205
|
assert (
|
|
@@ -174,6 +210,7 @@ def test_mapping_contexts():
|
|
|
174
210
|
assert set(_KERNEL_MAPPING.get().keys()) == {
|
|
175
211
|
"SiluAndMul",
|
|
176
212
|
"SiluAndMulStringDevice",
|
|
213
|
+
"SiluAndMulNoCompile",
|
|
177
214
|
}
|
|
178
215
|
|
|
179
216
|
|
|
@@ -203,3 +240,75 @@ def test_validate_kernel_layer():
|
|
|
203
240
|
|
|
204
241
|
with pytest.raises(TypeError, match="different kind of arguments"):
|
|
205
242
|
_validate_layer(cls=BadLayer4, check_cls=SiluAndMul)
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
def test_fallback_used_when_training():
|
|
246
|
+
@use_kernel_forward_from_hub("Linear")
|
|
247
|
+
class TorchLinear(nn.Linear):
|
|
248
|
+
def __init__(self, *args, **kwargs):
|
|
249
|
+
super().__init__(*args, **kwargs)
|
|
250
|
+
# Used to check that we called hub kernel.
|
|
251
|
+
self.n_calls = 0
|
|
252
|
+
|
|
253
|
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
|
254
|
+
self.n_calls += 1
|
|
255
|
+
return super().forward(input)
|
|
256
|
+
|
|
257
|
+
linear = TorchLinear(32, 32).to("cuda")
|
|
258
|
+
|
|
259
|
+
with use_kernel_mapping(
|
|
260
|
+
{
|
|
261
|
+
"Linear": {
|
|
262
|
+
Device(type="cuda"): LayerRepository(
|
|
263
|
+
repo_id="kernels-test/backward-marker-test",
|
|
264
|
+
layer_name="LinearImplicitBackward",
|
|
265
|
+
)
|
|
266
|
+
}
|
|
267
|
+
}
|
|
268
|
+
):
|
|
269
|
+
linear.train()
|
|
270
|
+
X = torch.randn(10, 32, device="cuda")
|
|
271
|
+
linear(X)
|
|
272
|
+
assert linear.n_calls == 0
|
|
273
|
+
|
|
274
|
+
linear.eval()
|
|
275
|
+
linear(X)
|
|
276
|
+
assert linear.n_calls == 0
|
|
277
|
+
|
|
278
|
+
with use_kernel_mapping(
|
|
279
|
+
{
|
|
280
|
+
"Linear": {
|
|
281
|
+
Device(type="cuda"): LayerRepository(
|
|
282
|
+
repo_id="kernels-test/backward-marker-test",
|
|
283
|
+
layer_name="LinearBackward",
|
|
284
|
+
)
|
|
285
|
+
}
|
|
286
|
+
}
|
|
287
|
+
):
|
|
288
|
+
linear.train()
|
|
289
|
+
X = torch.randn(10, 32, device="cuda")
|
|
290
|
+
linear(X)
|
|
291
|
+
assert linear.n_calls == 0
|
|
292
|
+
|
|
293
|
+
linear.eval()
|
|
294
|
+
linear(X)
|
|
295
|
+
assert linear.n_calls == 0
|
|
296
|
+
|
|
297
|
+
with use_kernel_mapping(
|
|
298
|
+
{
|
|
299
|
+
"Linear": {
|
|
300
|
+
Device(type="cuda"): LayerRepository(
|
|
301
|
+
repo_id="kernels-test/backward-marker-test",
|
|
302
|
+
layer_name="LinearNoBackward",
|
|
303
|
+
)
|
|
304
|
+
}
|
|
305
|
+
}
|
|
306
|
+
):
|
|
307
|
+
linear.train()
|
|
308
|
+
X = torch.randn(10, 32, device="cuda")
|
|
309
|
+
linear(X)
|
|
310
|
+
assert linear.n_calls == 1
|
|
311
|
+
|
|
312
|
+
linear.eval()
|
|
313
|
+
linear(X)
|
|
314
|
+
assert linear.n_calls == 1
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|