kernels 0.3.0__tar.gz → 0.3.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.4
2
2
  Name: kernels
3
- Version: 0.3.0
3
+ Version: 0.3.2
4
4
  Summary: Download compute kernels
5
5
  Author-email: OlivierDehaene <olivier@huggingface.co>, Daniel de Kok <daniel@huggingface.co>, David Holtz <david@huggingface.co>, Nicolas Patry <nicolas@huggingface.co>
6
6
  Requires-Python: >=3.9
@@ -8,7 +8,8 @@ Description-Content-Type: text/markdown
8
8
  Requires-Dist: huggingface-hub>=0.26.3
9
9
  Requires-Dist: packaging>=24.2
10
10
  Requires-Dist: tomli>=2.0.1; python_version < "3.11"
11
- Requires-Dist: torch>=2.5
11
+ Provides-Extra: torch
12
+ Requires-Dist: torch; extra == "torch"
12
13
 
13
14
  # kernels
14
15
 
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "kernels"
3
- version = "0.3.0"
3
+ version = "0.3.2"
4
4
  description = "Download compute kernels"
5
5
  authors = [
6
6
  { name = "OlivierDehaene", email = "olivier@huggingface.co" },
@@ -14,7 +14,6 @@ dependencies = [
14
14
  "huggingface-hub>=0.26.3",
15
15
  "packaging>=24.2",
16
16
  "tomli>=2.0.1; python_version<'3.11'",
17
- "torch>=2.5",
18
17
  ]
19
18
 
20
19
  [build-system]
@@ -27,8 +26,12 @@ dev = [
27
26
  "pytest >=8",
28
27
  # Whatever version is compatible with pytest.
29
28
  "pytest-benchmark",
29
+ "torch >=2.5",
30
30
  ]
31
31
 
32
+ [project.optional-dependencies]
33
+ torch = ["torch"]
34
+
32
35
  [project.scripts]
33
36
  kernels = "kernels.cli:main"
34
37
 
@@ -2,6 +2,7 @@ from kernels.layer import (
2
2
  Device,
3
3
  LayerRepository,
4
4
  register_kernel_mapping,
5
+ replace_kernel_forward_from_hub,
5
6
  use_kernel_forward_from_hub,
6
7
  )
7
8
  from kernels.utils import (
@@ -18,6 +19,7 @@ __all__ = [
18
19
  "install_kernel",
19
20
  "use_kernel_forward_from_hub",
20
21
  "register_kernel_mapping",
22
+ "replace_kernel_forward_from_hub",
21
23
  "LayerRepository",
22
24
  "Device",
23
25
  ]
@@ -114,16 +114,16 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
114
114
 
115
115
  cached_forward: Dict[LayerRepository, Callable] = {}
116
116
 
117
- def forward(self, x, **args):
117
+ def forward(self, x, *args, **kwargs):
118
118
  kernel = _KERNEL_MAPPING.get().get(layer_name)
119
119
  if kernel is None:
120
120
  if not use_fallback:
121
121
  raise ValueError(f"No layer mapping for `{layer_name}`")
122
- return fallback_forward(self, x, **args)
122
+ return fallback_forward(self, x, *args, **kwargs)
123
123
 
124
124
  device = getattr(x, "device", None)
125
125
  if device is None:
126
- return fallback_forward(self, x, **args)
126
+ return fallback_forward(self, x, *args, **kwargs)
127
127
 
128
128
  repo = kernel.get(Device(type=device.type))
129
129
  if repo is None:
@@ -131,12 +131,12 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
131
131
  raise ValueError(
132
132
  f"No layer mapping for `{layer_name}` with device type `{device.type}`"
133
133
  )
134
- return fallback_forward(self, x, **args)
134
+ return fallback_forward(self, x, *args, **kwargs)
135
135
 
136
136
  # Short-circuit if we already loaded the layer.
137
137
  layer_forward = cached_forward.get(repo, None)
138
138
  if layer_forward is not None:
139
- return layer_forward(self, x, **args)
139
+ return layer_forward(self, x, *args, **kwargs)
140
140
 
141
141
  layer = _get_kernel_layer(
142
142
  repo_id=repo.repo_id,
@@ -155,7 +155,7 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
155
155
  layer_forward = layer.forward
156
156
  cached_forward[repo] = layer_forward
157
157
 
158
- return layer_forward(self, x, **args)
158
+ return layer_forward(self, x, *args, **kwargs)
159
159
 
160
160
  cls.forward = forward
161
161
 
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.4
2
2
  Name: kernels
3
- Version: 0.3.0
3
+ Version: 0.3.2
4
4
  Summary: Download compute kernels
5
5
  Author-email: OlivierDehaene <olivier@huggingface.co>, Daniel de Kok <daniel@huggingface.co>, David Holtz <david@huggingface.co>, Nicolas Patry <nicolas@huggingface.co>
6
6
  Requires-Python: >=3.9
@@ -8,7 +8,8 @@ Description-Content-Type: text/markdown
8
8
  Requires-Dist: huggingface-hub>=0.26.3
9
9
  Requires-Dist: packaging>=24.2
10
10
  Requires-Dist: tomli>=2.0.1; python_version < "3.11"
11
- Requires-Dist: torch>=2.5
11
+ Provides-Extra: torch
12
+ Requires-Dist: torch; extra == "torch"
12
13
 
13
14
  # kernels
14
15
 
@@ -1,6 +1,8 @@
1
1
  huggingface-hub>=0.26.3
2
2
  packaging>=24.2
3
- torch>=2.5
4
3
 
5
4
  [:python_version < "3.11"]
6
5
  tomli>=2.0.1
6
+
7
+ [torch]
8
+ torch
@@ -53,6 +53,24 @@ class SiluAndMulStringDevice(SiluAndMul):
53
53
  pass
54
54
 
55
55
 
56
+ def test_arg_kinds():
57
+ @use_kernel_forward_from_hub("ArgKind")
58
+ class ArgKind(nn.Module):
59
+ def forward(
60
+ self,
61
+ arg1,
62
+ arg2,
63
+ *,
64
+ kwarg1,
65
+ kwarg2=42,
66
+ ):
67
+ return (arg1, arg2, kwarg1, kwarg2)
68
+
69
+ arg_kind = ArgKind()
70
+ assert arg_kind("foo", "bar", kwarg1="baz") == ("foo", "bar", "baz", 42)
71
+ assert arg_kind("foo", "bar", kwarg1="baz", kwarg2=5) == ("foo", "bar", "baz", 5)
72
+
73
+
56
74
  @pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice])
57
75
  @pytest.mark.parametrize("device", ["cuda", "cpu"])
58
76
  def test_hub_forward(cls, device):
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes