kernels 0.4.4.tar.gz → 0.5.0.dev0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
- Metadata-Version: 2.1
+ Metadata-Version: 2.4
  Name: kernels
- Version: 0.4.4
+ Version: 0.5.0.dev0
  Summary: Download compute kernels
  Author-email: OlivierDehaene <olivier@huggingface.co>, Daniel de Kok <daniel@huggingface.co>, David Holtz <david@huggingface.co>, Nicolas Patry <nicolas@huggingface.co>
  License: Apache-2.0
@@ -12,6 +12,7 @@ Requires-Dist: packaging>=20.0
  Requires-Dist: tomli>=2.0; python_version < "3.11"
  Provides-Extra: torch
  Requires-Dist: torch; extra == "torch"
+ Dynamic: license-file

  # kernels

@@ -1,6 +1,6 @@
  [project]
  name = "kernels"
- version = "0.4.4"
+ version = "0.5.0.dev0"
  description = "Download compute kernels"
  authors = [
      { name = "OlivierDehaene", email = "olivier@huggingface.co" },
@@ -138,6 +138,8 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
              return fallback_forward(self, x, *args, **kwargs)

          needs_backward = self.training
+         is_compiling = _is_torchdynamo_compiling()
+
          kernel = _KERNEL_MAPPING.get().get(layer_name)
          if kernel is None:
              warnings.warn(
@@ -165,7 +167,14 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
          # Short-circuit if we already loaded the layer.
          layer = cached_layer.get(repo, None)
          if layer is not None:
-             if needs_backward and not getattr(layer, "has_backward", True):
+             # Fall back when the layer does not support what this call needs:
+             # - torch.compile, if we are currently compiling;
+             # - backward, if gradients are required.
+             needs_fallback = needs_backward and not getattr(layer, "has_backward", True)
+             needs_fallback |= is_compiling and not getattr(
+                 layer, "can_torch_compile", False
+             )
+             if needs_fallback:
                  return fallback_forward(self, x, *args, **kwargs)
              return layer.forward(self, x, *args, **kwargs)

@@ -185,8 +194,15 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool

          cached_layer[repo] = layer

-         if needs_backward and not getattr(layer, "has_backward", True):
+         # Fall back when the layer does not support what this call needs
+         # (backward when training, torch.compile when compiling).
+         needs_fallback = needs_backward and not getattr(layer, "has_backward", True)
+         needs_fallback |= is_compiling and not getattr(
+             layer, "can_torch_compile", False
+         )
+         if needs_fallback:
              return fallback_forward(self, x, *args, **kwargs)
+
          return layer.forward(self, x, *args, **kwargs)

      cls.forward = forward
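
For orientation, the dispatch above inspects just two optional class attributes on the loaded layer, with getattr defaults of has_backward=True and can_torch_compile=False. Below is a minimal sketch of a layer that advertises both flags; the class name and flag values are illustrative assumptions, and the forward body mirrors the SiluAndMul reference used in the test file further down.

import torch
from torch import nn
import torch.nn.functional as F


class SiluAndMulKernelLayer(nn.Module):
    # Capability flags read via getattr() by the hub dispatch above.
    # When absent they default to has_backward=True, can_torch_compile=False.
    has_backward = False      # inference-only: training routes to the fallback forward
    can_torch_compile = True  # safe to trace, e.g. the op registers a fake/meta impl

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        d = input.shape[-1] // 2
        return F.silu(input[..., :d]) * input[..., d:]
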
@@ -245,7 +261,8 @@ def _validate_layer(*, check_cls, cls):
      torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
      cls_members = {name for name, _ in inspect.getmembers(cls)}
      difference = cls_members - torch_module_members
-     if difference != set() and difference != {"has_backward"}:
+     # Allow only the capability flags: difference must be a subset of {"can_torch_compile", "has_backward"}.
+     if not difference <= {"can_torch_compile", "has_backward"}:
          raise TypeError("Layer must not contain additional members.")

      # Check whether the forward signatures are similar.
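
The rewritten membership check relies on Python's subset operator on sets, which subsumes both arms of the old equality test; a quick illustration:

allowed = {"can_torch_compile", "has_backward"}

# Old check: difference != set() and difference != {"has_backward"}
# New check: reject anything that is not a subset of the allowed flags.
assert set() <= allowed                  # no extra members: accepted
assert {"has_backward"} <= allowed       # previously the only allowed extra
assert {"can_torch_compile", "has_backward"} <= allowed  # now also accepted
assert not {"my_helper"} <= allowed      # anything else still raises TypeError
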
@@ -262,3 +279,19 @@ def _validate_layer(*, check_cls, cls):
              raise TypeError(
                  f"Forward signature does not match: different kind of arguments ({param} ({param.kind}) and {ref_param} ({ref_param.kind})"
              )
+
+
+ def _is_torchdynamo_compiling():
+     # Importing torch._dynamo causes issues with the PyTorch profiler (https://github.com/pytorch/pytorch/issues/130622),
+     # so rely on `torch.compiler.is_compiling()` when possible (torch>=2.3).
+     try:
+         import torch
+
+         return torch.compiler.is_compiling()
+     except Exception:
+         try:
+             import torch._dynamo as dynamo  # noqa: F401
+
+             return dynamo.is_compiling()
+         except Exception:
+             return False
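
A rough sanity check of the helper's intended behavior (assumes torch>=2.3 is installed; the probe function is hypothetical):

import torch

# In eager mode neither branch reports compilation.
assert not torch.compiler.is_compiling()


@torch.compile
def probe(x):
    # While Dynamo traces this body, torch.compiler.is_compiling() returns
    # True, which is what steers hub layers lacking can_torch_compile=True
    # onto their fallback forward.
    return x + 1


probe(torch.ones(2))
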
@@ -1,6 +1,6 @@
- Metadata-Version: 2.1
+ Metadata-Version: 2.4
  Name: kernels
- Version: 0.4.4
+ Version: 0.5.0.dev0
  Summary: Download compute kernels
  Author-email: OlivierDehaene <olivier@huggingface.co>, Daniel de Kok <daniel@huggingface.co>, David Holtz <david@huggingface.co>, Nicolas Patry <nicolas@huggingface.co>
  License: Apache-2.0
@@ -12,6 +12,7 @@ Requires-Dist: packaging>=20.0
  Requires-Dist: tomli>=2.0; python_version < "3.11"
  Provides-Extra: torch
  Requires-Dist: torch; extra == "torch"
+ Dynamic: license-file

  # kernels

@@ -19,6 +19,12 @@ kernel_layer_mapping = {
              revision="layers",
          )
      },
+     "SiluAndMulNoCompile": {
+         "cuda": LayerRepository(
+             repo_id="kernels-test/op-without-fake-test",
+             layer_name="SiluAndMul",
+         )
+     },
      "SiluAndMulStringDevice": {
          "cuda": LayerRepository(
              repo_id="kernels-community/activation",
@@ -43,6 +49,11 @@ class SiluAndMul(nn.Module):
          return F.silu(input[..., :d]) * input[..., d:]


+ @use_kernel_forward_from_hub("SiluAndMulNoCompile")
+ class SiluAndMulNoCompileKernel(SiluAndMul):
+     pass
+
+
  @use_kernel_forward_from_hub("SiluAndMul")
  class SiluAndMulWithKernel(SiluAndMul):
      pass
@@ -101,8 +112,29 @@ def test_layer_fallback_works():
      SiluAndMulWithKernelFallback()


+ @pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
+ @pytest.mark.parametrize("device", ["cuda", "cpu"])
+ def test_torch_compile_layer(cls, device):
+     silu_and_mul = SiluAndMul()
+
+     X = torch.randn((32, 64), dtype=torch.float32, device=device)
+     Y = silu_and_mul(X)
+
+     silu_and_mul_with_kernel = cls()
+     silu_and_mul_with_kernel.eval()
+     silu_and_mul_compiled = torch.compile(silu_and_mul_with_kernel)
+
+     Y_compiled = silu_and_mul_compiled(X)
+
+     torch.testing.assert_close(Y_compiled, Y)
+
+
  def test_mapping_contexts():
-     assert set(_KERNEL_MAPPING.get().keys()) == {"SiluAndMul", "SiluAndMulStringDevice"}
+     assert set(_KERNEL_MAPPING.get().keys()) == {
+         "SiluAndMul",
+         "SiluAndMulStringDevice",
+         "SiluAndMulNoCompile",
+     }

      extra_mapping1 = {
          "TestKernel": {
@@ -118,6 +150,7 @@ def test_mapping_contexts():
      assert set(_KERNEL_MAPPING.get().keys()) == {
          "SiluAndMul",
          "SiluAndMulStringDevice",
+         "SiluAndMulNoCompile",
          "TestKernel",
      }

@@ -135,6 +168,7 @@ def test_mapping_contexts():
      assert set(_KERNEL_MAPPING.get().keys()) == {
          "SiluAndMul",
          "SiluAndMulStringDevice",
+         "SiluAndMulNoCompile",
          "TestKernel",
      }
      assert (
@@ -145,6 +179,7 @@ def test_mapping_contexts():
      assert set(_KERNEL_MAPPING.get().keys()) == {
          "SiluAndMul",
          "SiluAndMulStringDevice",
+         "SiluAndMulNoCompile",
          "TestKernel",
      }
      assert (
@@ -164,6 +199,7 @@ def test_mapping_contexts():
      assert set(_KERNEL_MAPPING.get().keys()) == {
          "SiluAndMul",
          "SiluAndMulStringDevice",
+         "SiluAndMulNoCompile",
          "TestKernel",
      }
      assert (
@@ -174,6 +210,7 @@ def test_mapping_contexts():
      assert set(_KERNEL_MAPPING.get().keys()) == {
          "SiluAndMul",
          "SiluAndMulStringDevice",
+         "SiluAndMulNoCompile",
      }
