kernels-data 0.0.1__tar.gz → 0.14.0.dev1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (22)
  1. {kernels_data-0.0.1 → kernels_data-0.14.0.dev1}/Cargo.lock +5 -5
  2. {kernels_data-0.0.1 → kernels_data-0.14.0.dev1}/PKG-INFO +1 -1
  3. {kernels_data-0.0.1 → kernels_data-0.14.0.dev1}/kernels-data/Cargo.toml +1 -1
  4. {kernels_data-0.0.1 → kernels_data-0.14.0.dev1}/kernels-data/bindings/python/Cargo.toml +1 -1
  5. {kernels_data-0.0.1 → kernels_data-0.14.0.dev1}/kernels-data/bindings/python/kernels_data.pyi +41 -28
  6. {kernels_data-0.0.1 → kernels_data-0.14.0.dev1}/kernels-data/bindings/python/src/lib.rs +67 -39
  7. kernels_data-0.14.0.dev1/kernels-data/bindings/python/stubtest-allowlist.txt +3 -0
  8. {kernels_data-0.0.1 → kernels_data-0.14.0.dev1}/kernels-data/bindings/python/tests/test_kernels_data.py +43 -32
  9. {kernels_data-0.0.1 → kernels_data-0.14.0.dev1}/kernels-data/src/config/mod.rs +10 -1
  10. {kernels_data-0.0.1 → kernels_data-0.14.0.dev1}/kernels-data/src/config/v1.rs +1 -0
  11. {kernels_data-0.0.1 → kernels_data-0.14.0.dev1}/kernels-data/src/config/v2.rs +1 -0
  12. {kernels_data-0.0.1 → kernels_data-0.14.0.dev1}/kernels-data/src/config/v3.rs +3 -0
  13. {kernels_data-0.0.1 → kernels_data-0.14.0.dev1}/kernels-data/src/metadata.rs +5 -1
  14. {kernels_data-0.0.1 → kernels_data-0.14.0.dev1}/pyproject.toml +1 -0
  15. {kernels_data-0.0.1 → kernels_data-0.14.0.dev1}/Cargo.toml +0 -0
  16. {kernels_data-0.0.1 → kernels_data-0.14.0.dev1}/kernels-data/bindings/python/MANIFEST.in +0 -0
  17. {kernels_data-0.0.1 → kernels_data-0.14.0.dev1}/kernels-data/src/config/compat.rs +0 -0
  18. {kernels_data-0.0.1 → kernels_data-0.14.0.dev1}/kernels-data/src/config/deps.rs +0 -0
  19. {kernels_data-0.0.1 → kernels_data-0.14.0.dev1}/kernels-data/src/config/name.rs +0 -0
  20. {kernels_data-0.0.1 → kernels_data-0.14.0.dev1}/kernels-data/src/lib.rs +0 -0
  21. {kernels_data-0.0.1 → kernels_data-0.14.0.dev1}/kernels-data/src/python_dependencies.json +0 -0
  22. {kernels_data-0.0.1 → kernels_data-0.14.0.dev1}/kernels-data/src/version.rs +0 -0
@@ -1047,7 +1047,7 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
1047
1047
 
1048
1048
  [[package]]
1049
1049
  name = "hf-kernel-builder"
1050
- version = "0.13.0-dev0"
1050
+ version = "0.14.0-dev1"
1051
1051
  dependencies = [
1052
1052
  "base32",
1053
1053
  "clap",
@@ -1526,7 +1526,7 @@ dependencies = [
1526
1526
 
1527
1527
  [[package]]
1528
1528
  name = "kernel-abi-check"
1529
- version = "0.13.0-dev0"
1529
+ version = "0.14.0-dev1"
1530
1530
  dependencies = [
1531
1531
  "clap",
1532
1532
  "color-eyre",
@@ -1541,7 +1541,7 @@ dependencies = [
1541
1541
 
1542
1542
  [[package]]
1543
1543
  name = "kernel-abi-check-python"
1544
- version = "0.13.0-dev0"
1544
+ version = "0.14.0-dev1"
1545
1545
  dependencies = [
1546
1546
  "kernel-abi-check",
1547
1547
  "object 0.36.7",
@@ -1550,7 +1550,7 @@ dependencies = [
1550
1550
 
1551
1551
  [[package]]
1552
1552
  name = "kernels-data"
1553
- version = "0.1.0"
1553
+ version = "0.14.0-dev1"
1554
1554
  dependencies = [
1555
1555
  "eyre",
1556
1556
  "itertools 0.13.0",
@@ -1564,7 +1564,7 @@ dependencies = [
1564
1564
 
1565
1565
  [[package]]
1566
1566
  name = "kernels-data-python"
1567
- version = "0.0.1"
1567
+ version = "0.14.0-dev1"
1568
1568
  dependencies = [
1569
1569
  "kernels-data",
1570
1570
  "pyo3",
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: kernels-data
3
- Version: 0.0.1
3
+ Version: 0.14.0.dev1
4
4
  Classifier: Programming Language :: Rust
5
5
  Classifier: Programming Language :: Python :: Implementation :: CPython
6
6
  Summary: Kernels data structures (Python bindings)
@@ -1,6 +1,6 @@
1
1
  [package]
2
2
  name = "kernels-data"
3
- version = "0.1.0"
3
+ version = "0.14.0-dev1"
4
4
  edition = "2024"
5
5
  description = "Kernels data structures"
6
6
  homepage = "https://github.com/huggingface/kernels"
@@ -1,6 +1,6 @@
1
1
  [package]
2
2
  name = "kernels-data-python"
3
- version = "0.0.1"
3
+ version = "0.14.0-dev1"
4
4
  edition = "2024"
5
5
  description = "Kernels data structures (Python bindings)"
6
6
  homepage = "https://github.com/huggingface/kernels"
@@ -1,19 +1,23 @@
1
1
  """Type stubs for kernels_data module."""
2
2
 
3
- from enum import Enum
4
- from typing import Optional
5
3
  import os
4
+ from enum import Enum
5
+ from typing import Optional, final
6
+
7
+ __all__ = ["Backend", "BackendInfo", "KernelName", "Metadata", "Version", "__version__"]
6
8
 
7
9
  __version__: str
8
10
 
11
+ @final
9
12
  class Backend(Enum):
10
13
  """Kernel backend (hardware target)."""
11
14
 
15
+ CANN = "CANN"
12
16
  CPU = "CPU"
13
17
  CUDA = "CUDA"
14
- METAL = "METAL"
15
- NEURON = "NEURON"
16
- ROCM = "ROCM"
18
+ Metal = "Metal"
19
+ Neuron = "Neuron"
20
+ ROCm = "ROCm"
17
21
  XPU = "XPU"
18
22
 
19
23
  @staticmethod
@@ -21,8 +25,8 @@ class Backend(Enum):
21
25
  """Parse a backend name.
22
26
 
23
27
  Args:
24
- s: One of ``"cpu"``, ``"cuda"``, ``"metal"``, ``"neuron"``,
25
- ``"rocm"``, ``"xpu"``.
28
+ s: One of ``"cann"``, ``"cpu"``, ``"cuda"``, ``"metal"``,
29
+ ``"neuron"``, ``"rocm"``, ``"xpu"``.
26
30
 
27
31
  Raises:
28
32
  ValueError: If the backend name is unknown.
@@ -32,6 +36,23 @@ class Backend(Enum):
32
36
  def __str__(self) -> str: ...
33
37
  def __repr__(self) -> str: ...
34
38
 
39
+ @final
40
+ class BackendInfo:
41
+ """Backend information."""
42
+
43
+ @property
44
+ def backend_type(self) -> Backend:
45
+ """Return the backend type."""
46
+ ...
47
+
48
+ @property
49
+ def archs(self) -> Optional[list[str]]:
50
+ """Optional list of target architectures."""
51
+ ...
52
+
53
+ def __repr__(self) -> str: ...
54
+
55
+ @final
35
56
  class Version:
36
57
  """A dotted numeric version (e.g. ``12.8.0``).
37
58
 
@@ -49,17 +70,18 @@ class Version:
49
70
 
50
71
  def __str__(self) -> str: ...
51
72
  def __repr__(self) -> str: ...
52
- def __eq__(self, other: object) -> bool: ...
53
- def __lt__(self, other: "Version") -> bool: ...
54
- def __le__(self, other: "Version") -> bool: ...
55
- def __gt__(self, other: "Version") -> bool: ...
56
- def __ge__(self, other: "Version") -> bool: ...
73
+ def __eq__(self, value: object, /) -> bool: ...
74
+ def __lt__(self, value: "Version", /) -> bool: ...
75
+ def __le__(self, value: "Version", /) -> bool: ...
76
+ def __gt__(self, value: "Version", /) -> bool: ...
77
+ def __ge__(self, value: "Version", /) -> bool: ...
57
78
  def __hash__(self) -> int: ...
58
79
 
80
+ @final
59
81
  class KernelName:
60
82
  """A validated kernel name matching ``^[a-z][-a-z0-9]*[a-z0-9]$``."""
61
83
 
62
- def __init__(self, name: str) -> None:
84
+ def __new__(cls, name: str) -> "KernelName":
63
85
  """Create a new ``KernelName``.
64
86
 
65
87
  Raises:
@@ -74,20 +96,19 @@ class KernelName:
74
96
 
75
97
  def __str__(self) -> str: ...
76
98
  def __repr__(self) -> str: ...
77
- def __eq__(self, other: object) -> bool: ...
99
+ def __eq__(self, value: object, /) -> bool: ...
78
100
  def __hash__(self) -> int: ...
79
101
 
102
+ @final
80
103
  class Metadata:
81
104
  """Parsed ``metadata.json`` for a kernel build variant."""
82
105
 
83
106
  @staticmethod
84
- def load_from_variant(variant_path: os.PathLike[str] | str) -> Optional["Metadata"]:
85
- """Load ``metadata.json`` from a build variant directory.
86
-
87
- Returns ``None`` if the file does not exist in ``variant_path``.
107
+ def load(metadata_path: os.PathLike[str] | str) -> "Metadata":
108
+ """Parse ``metadata.json`` at the given path.
88
109
 
89
110
  Raises:
90
- ValueError: If the file exists but cannot be parsed.
111
+ ValueError: On any I/O or parse error.
91
112
  """
92
113
  ...
93
114
 
@@ -100,13 +121,5 @@ class Metadata:
100
121
  @property
101
122
  def python_depends(self) -> list[str]: ...
102
123
  @property
103
- def backend(self) -> Backend: ...
124
+ def backend(self) -> BackendInfo: ...
104
125
  def __repr__(self) -> str: ...
105
-
106
- def parse_metadata(path: os.PathLike[str] | str) -> Metadata:
107
- """Parse a kernel ``metadata.json`` file.
108
-
109
- Raises:
110
- ValueError: On any I/O or parse error.
111
- """
112
- ...
@@ -2,7 +2,7 @@ use std::path::PathBuf;
2
2
  use std::str::FromStr;
3
3
 
4
4
  use kernels_data::config::{Backend, KernelName};
5
- use kernels_data::metadata::{Metadata, parse_metadata};
5
+ use kernels_data::metadata::{BackendInfo, Metadata, parse_metadata};
6
6
  use kernels_data::version::Version;
7
7
  use pyo3::Bound as PyBound;
8
8
  use pyo3::exceptions::PyValueError;
@@ -72,15 +72,17 @@ impl PyKernelName {
72
72
  #[pyclass(name = "Backend", eq, frozen, hash)]
73
73
  #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
74
74
  enum PyBackend {
75
+ #[pyo3(name = "CANN")]
76
+ Cann,
75
77
  #[pyo3(name = "CPU")]
76
78
  Cpu,
77
79
  #[pyo3(name = "CUDA")]
78
80
  Cuda,
79
- #[pyo3(name = "METAL")]
81
+ #[pyo3(name = "Metal")]
80
82
  Metal,
81
- #[pyo3(name = "NEURON")]
83
+ #[pyo3(name = "Neuron")]
82
84
  Neuron,
83
- #[pyo3(name = "ROCM")]
85
+ #[pyo3(name = "ROCm")]
84
86
  Rocm,
85
87
  #[pyo3(name = "XPU")]
86
88
  Xpu,
@@ -89,6 +91,7 @@ enum PyBackend {
89
91
  impl From<Backend> for PyBackend {
90
92
  fn from(b: Backend) -> Self {
91
93
  match b {
94
+ Backend::Cann => PyBackend::Cann,
92
95
  Backend::Cpu => PyBackend::Cpu,
93
96
  Backend::Cuda => PyBackend::Cuda,
94
97
  Backend::Metal => PyBackend::Metal,
@@ -102,6 +105,7 @@ impl From<Backend> for PyBackend {
102
105
  impl From<PyBackend> for Backend {
103
106
  fn from(b: PyBackend) -> Self {
104
107
  match b {
108
+ PyBackend::Cann => Backend::Cann,
105
109
  PyBackend::Cpu => Backend::Cpu,
106
110
  PyBackend::Cuda => Backend::Cuda,
107
111
  PyBackend::Metal => Backend::Metal,
@@ -114,8 +118,8 @@ impl From<PyBackend> for Backend {
114
118
 
115
119
  #[pymethods]
116
120
  impl PyBackend {
117
- /// Parse a backend name (`"cpu"`, `"cuda"`, `"metal"`, `"neuron"`,
118
- /// `"rocm"`, `"xpu"`).
121
+ /// Parse a backend name (`"cann"`, `"cpu"`, `"cuda"`, `"metal"`,
122
+ /// `"neuron"`, `"rocm"`, `"xpu"`).
119
123
  #[staticmethod]
120
124
  #[pyo3(name = "from_str")]
121
125
  fn py_from_str(s: &str) -> PyResult<Self> {
@@ -130,17 +134,56 @@ impl PyBackend {
130
134
 
131
135
  fn __repr__(&self) -> String {
132
136
  let variant = match self {
137
+ PyBackend::Cann => "CANN",
133
138
  PyBackend::Cpu => "CPU",
134
139
  PyBackend::Cuda => "CUDA",
135
- PyBackend::Metal => "METAL",
136
- PyBackend::Neuron => "NEURON",
137
- PyBackend::Rocm => "ROCM",
140
+ PyBackend::Metal => "Metal",
141
+ PyBackend::Neuron => "Neuron",
142
+ PyBackend::Rocm => "ROCm",
138
143
  PyBackend::Xpu => "XPU",
139
144
  };
140
145
  format!("Backend.{variant}")
141
146
  }
142
147
  }
143
148
 
149
+ /// Backend information
150
+ #[pyclass(name = "BackendInfo", frozen)]
151
+ #[derive(Clone, Debug)]
152
+ struct PyBackendInfo {
153
+ backend_type: PyBackend,
154
+ archs: Option<Vec<String>>,
155
+ }
156
+
157
+ impl From<BackendInfo> for PyBackendInfo {
158
+ fn from(backend_info: BackendInfo) -> Self {
159
+ Self {
160
+ backend_type: backend_info.backend_type.into(),
161
+ archs: backend_info.archs,
162
+ }
163
+ }
164
+ }
165
+
166
+ #[pymethods]
167
+ impl PyBackendInfo {
168
+ fn __repr__(&self) -> String {
169
+ format!(
170
+ "BackendInfo(backend_type={}, archs={:?})",
171
+ self.backend_type.__repr__(),
172
+ self.archs
173
+ )
174
+ }
175
+
176
+ #[getter]
177
+ fn backend_type(&self) -> PyBackend {
178
+ self.backend_type
179
+ }
180
+
181
+ #[getter]
182
+ fn archs(&self) -> Option<&[String]> {
183
+ self.archs.as_deref()
184
+ }
185
+ }
186
+
144
187
  /// Parsed `metadata.json` for a kernel build variant.
145
188
  #[pyclass(name = "Metadata", frozen)]
146
189
  #[derive(Clone, Debug)]
@@ -149,7 +192,7 @@ struct PyMetadata {
149
192
  license: Option<String>,
150
193
  upstream: Option<String>,
151
194
  python_depends: Vec<String>,
152
- backend: PyBackend,
195
+ backend: PyBackendInfo,
153
196
  }
154
197
 
155
198
  impl From<Metadata> for PyMetadata {
@@ -159,25 +202,20 @@ impl From<Metadata> for PyMetadata {
159
202
  license: m.license,
160
203
  upstream: m.upstream.map(|u| u.to_string()),
161
204
  python_depends: m.python_depends,
162
- backend: m.backend.backend_type.into(),
205
+ backend: m.backend.into(),
163
206
  }
164
207
  }
165
208
  }
166
209
 
167
210
  #[pymethods]
168
211
  impl PyMetadata {
169
- /// Load `metadata.json` from a build variant directory.
212
+ /// Parse `metadata.json` at the given path.
170
213
  ///
171
- /// Returns `None` if the file does not exist in the given directory.
172
- /// Raises `ValueError` if the file exists but cannot be parsed.
214
+ /// Raises `ValueError` on any I/O or parse error.
173
215
  #[staticmethod]
174
- fn load_from_variant(variant_path: PathBuf) -> PyResult<Option<Self>> {
175
- let metadata_path = variant_path.join("metadata.json");
176
- if !metadata_path.exists() {
177
- return Ok(None);
178
- }
216
+ fn load(metadata_path: PathBuf) -> PyResult<Self> {
179
217
  parse_metadata(&metadata_path)
180
- .map(|m| Some(m.into()))
218
+ .map(Into::into)
181
219
  .map_err(|err| PyValueError::new_err(format!("{err:#}")))
182
220
  }
183
221
 
@@ -187,23 +225,23 @@ impl PyMetadata {
187
225
  }
188
226
 
189
227
  #[getter]
190
- fn license(&self) -> Option<&str> {
191
- self.license.as_deref()
228
+ fn license(&self) -> Option<&String> {
229
+ self.license.as_ref()
192
230
  }
193
231
 
194
232
  #[getter]
195
- fn upstream(&self) -> Option<&str> {
196
- self.upstream.as_deref()
233
+ fn upstream(&self) -> Option<&String> {
234
+ self.upstream.as_ref()
197
235
  }
198
236
 
199
237
  #[getter]
200
- fn python_depends(&self) -> Vec<String> {
201
- self.python_depends.clone()
238
+ fn python_depends(&self) -> &[String] {
239
+ &self.python_depends
202
240
  }
203
241
 
204
242
  #[getter]
205
- fn backend(&self) -> PyBackend {
206
- self.backend
243
+ fn backend(&self) -> PyBackendInfo {
244
+ self.backend.clone()
207
245
  }
208
246
 
209
247
  fn __repr__(&self) -> String {
@@ -218,23 +256,13 @@ impl PyMetadata {
218
256
  }
219
257
  }
220
258
 
221
- /// Parse a kernel `metadata.json` file.
222
- ///
223
- /// Raises `ValueError` on any I/O or parse error.
224
- #[pyfunction(name = "parse_metadata")]
225
- fn py_parse_metadata(path: PathBuf) -> PyResult<PyMetadata> {
226
- parse_metadata(&path)
227
- .map(Into::into)
228
- .map_err(|err| PyValueError::new_err(format!("{err:#}")))
229
- }
230
-
231
259
  #[pyo3::pymodule(name = "kernels_data")]
232
260
  fn kernels_data_py(m: &PyBound<'_, PyModule>) -> PyResult<()> {
233
261
  m.add_class::<PyBackend>()?;
262
+ m.add_class::<PyBackendInfo>()?;
234
263
  m.add_class::<PyKernelName>()?;
235
264
  m.add_class::<PyMetadata>()?;
236
265
  m.add_class::<PyVersion>()?;
237
- m.add_function(wrap_pyfunction!(py_parse_metadata, m)?)?;
238
266
 
239
267
  m.add("__version__", env!("CARGO_PKG_VERSION"))?;
240
268
  Ok(())
@@ -0,0 +1,3 @@
1
+ # The `.so` submodule `kernels_data.kernels_data` is a pyo3/maturin
2
+ # implementation detail; end users import from the `kernels_data` package.
3
+ kernels_data\.kernels_data
@@ -2,12 +2,11 @@ import json
2
2
 
3
3
  import pytest
4
4
 
5
- from kernels_data import Backend, KernelName, Metadata, Version, parse_metadata
5
+ from kernels_data import Backend, KernelName, Metadata, Version
6
6
 
7
7
 
8
- def _write_metadata(variant_dir, **fields):
9
- variant_dir.mkdir(parents=True, exist_ok=True)
10
- path = variant_dir / "metadata.json"
8
+ def _write_metadata(path, **fields):
9
+ path.parent.mkdir(parents=True, exist_ok=True)
11
10
  path.write_text(json.dumps(fields))
12
11
  return path
13
12
 
@@ -72,7 +71,21 @@ def test_backend_unknown():
72
71
  Backend.from_str("tpu")
73
72
 
74
73
 
75
- def test_metadata_parse_full(tmp_path):
74
+ def test_backend_all_variants_and_casing():
75
+ assert str(Backend.Metal) == "metal"
76
+ assert repr(Backend.Metal) == "Backend.Metal"
77
+ assert str(Backend.Neuron) == "neuron"
78
+ assert repr(Backend.Neuron) == "Backend.Neuron"
79
+ assert str(Backend.ROCm) == "rocm"
80
+ assert repr(Backend.ROCm) == "Backend.ROCm"
81
+ assert repr(Backend.XPU) == "Backend.XPU"
82
+ assert repr(Backend.CANN) == "Backend.CANN"
83
+ assert Backend.from_str("cann") == Backend.CANN
84
+ assert Backend.from_str("ROCM") == Backend.ROCm
85
+ assert Backend.from_str("metal") == Backend.Metal
86
+
87
+
88
+ def test_metadata_load_full(tmp_path):
76
89
  path = tmp_path / "metadata.json"
77
90
  path.write_text(
78
91
  json.dumps(
@@ -81,32 +94,37 @@ def test_metadata_parse_full(tmp_path):
81
94
  "license": "Apache-2.0",
82
95
  "upstream": "https://github.com/example/kernel",
83
96
  "python-depends": ["torch"],
84
- "backend": {"type": "cuda"},
97
+ "backend": {"type": "cuda", "archs": ["9.0", "10.0"]},
85
98
  }
86
99
  )
87
100
  )
88
- m = parse_metadata(path)
101
+ m = Metadata.load(path)
89
102
  assert m.version == 1
90
103
  assert m.license == "Apache-2.0"
91
104
  assert m.upstream == "https://github.com/example/kernel"
92
105
  assert m.python_depends == ["torch"]
93
- assert m.backend == Backend.CUDA
106
+ assert m.backend.backend_type == Backend.CUDA
107
+ assert m.backend.archs == ["9.0", "10.0"]
94
108
 
95
109
 
96
- def test_metadata_parse_minimal(tmp_path):
110
+ def test_metadata_load_minimal(tmp_path):
97
111
  path = tmp_path / "metadata.json"
98
- path.write_text(
99
- json.dumps({"python-depends": [], "backend": {"type": "cpu"}})
100
- )
101
- m = parse_metadata(path)
112
+ path.write_text(json.dumps({"python-depends": [], "backend": {"type": "cpu"}}))
113
+ m = Metadata.load(path)
102
114
  assert m.version is None
103
115
  assert m.license is None
104
116
  assert m.upstream is None
105
117
  assert m.python_depends == []
106
- assert m.backend == Backend.CPU
118
+ assert m.backend.backend_type == Backend.CPU
119
+
120
+
121
+ def test_metadata_load_cann(tmp_path):
122
+ path = tmp_path / "metadata.json"
123
+ path.write_text(json.dumps({"python-depends": [], "backend": {"type": "cann"}}))
124
+ assert Metadata.load(path).backend.backend_type == Backend.CANN
107
125
 
108
126
 
109
- def test_metadata_parse_unknown_field_rejected(tmp_path):
127
+ def test_metadata_load_unknown_field_accepted(tmp_path):
110
128
  path = tmp_path / "metadata.json"
111
129
  path.write_text(
112
130
  json.dumps(
@@ -117,32 +135,25 @@ def test_metadata_parse_unknown_field_rejected(tmp_path):
117
135
  }
118
136
  )
119
137
  )
120
- with pytest.raises(ValueError):
121
- parse_metadata(path)
138
+ Metadata.load(path)
122
139
 
123
140
 
124
- def test_metadata_parse_malformed(tmp_path):
141
+ def test_metadata_load_malformed(tmp_path):
125
142
  path = tmp_path / "metadata.json"
126
143
  path.write_text("{not json")
127
144
  with pytest.raises(ValueError):
128
- parse_metadata(path)
145
+ Metadata.load(path)
129
146
 
130
147
 
131
- def test_metadata_load_from_variant(tmp_path):
132
- _write_metadata(
133
- tmp_path / "variant",
148
+ def test_metadata_load(tmp_path):
149
+ path = _write_metadata(
150
+ tmp_path / "variant" / "metadata.json",
134
151
  **{"python-depends": ["torch"], "backend": {"type": "cuda"}},
135
152
  )
136
- m = Metadata.load_from_variant(tmp_path / "variant")
137
- assert m is not None
138
- assert m.backend == Backend.CUDA
139
-
140
-
141
- def test_metadata_load_from_variant_missing(tmp_path):
142
- (tmp_path / "empty-variant").mkdir()
143
- assert Metadata.load_from_variant(tmp_path / "empty-variant") is None
153
+ m = Metadata.load(path)
154
+ assert m.backend.backend_type == Backend.CUDA
144
155
 
145
156
 
146
- def test_metadata_parse_missing_file(tmp_path):
157
+ def test_metadata_load_missing_file(tmp_path):
147
158
  with pytest.raises(ValueError):
148
- parse_metadata(tmp_path / "does-not-exist.json")
159
+ Metadata.load(tmp_path / "does-not-exist.json")
@@ -53,6 +53,10 @@ impl Build {
53
53
  self.kernels.is_empty()
54
54
  }
55
55
 
56
+ pub fn branch(&self) -> Option<&str> {
57
+ self.general.hub.as_ref().and_then(|h| h.branch.as_deref())
58
+ }
59
+
56
60
  pub fn repo_id(&self) -> Option<&str> {
57
61
  self.general.hub.as_ref().and_then(|h| h.repo_id.as_deref())
58
62
  }
@@ -286,6 +290,7 @@ impl Kernel {
286
290
  #[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize)]
287
291
  #[serde(deny_unknown_fields, rename_all = "kebab-case")]
288
292
  pub enum Backend {
293
+ Cann,
289
294
  Cpu,
290
295
  Cuda,
291
296
  Metal,
@@ -295,8 +300,9 @@ pub enum Backend {
295
300
  }
296
301
 
297
302
  impl Backend {
298
- pub const fn all() -> [Backend; 6] {
303
+ pub const fn all() -> [Backend; 7] {
299
304
  [
305
+ Backend::Cann,
300
306
  Backend::Cpu,
301
307
  Backend::Cuda,
302
308
  Backend::Metal,
@@ -308,6 +314,7 @@ impl Backend {
308
314
 
309
315
  pub const fn as_str(&self) -> &'static str {
310
316
  match self {
317
+ Backend::Cann => "cann",
311
318
  Backend::Cpu => "cpu",
312
319
  Backend::Cuda => "cuda",
313
320
  Backend::Metal => "metal",
@@ -321,6 +328,7 @@ impl Backend {
321
328
  impl Display for Backend {
322
329
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
323
330
  match self {
331
+ Backend::Cann => write!(f, "cann"),
324
332
  Backend::Cpu => write!(f, "cpu"),
325
333
  Backend::Cuda => write!(f, "cuda"),
326
334
  Backend::Metal => write!(f, "metal"),
@@ -336,6 +344,7 @@ impl FromStr for Backend {
336
344
 
337
345
  fn from_str(s: &str) -> Result<Self, Self::Err> {
338
346
  match s.to_lowercase().as_str() {
347
+ "cann" => Ok(Backend::Cann),
339
348
  "cpu" => Ok(Backend::Cpu),
340
349
  "cuda" => Ok(Backend::Cuda),
341
350
  "metal" => Ok(Backend::Metal),
@@ -83,6 +83,7 @@ impl TryFrom<Build> for super::Build {
83
83
 
84
84
  let backends = if universal {
85
85
  vec![
86
+ Backend::Cann,
86
87
  Backend::Cpu,
87
88
  Backend::Cuda,
88
89
  Backend::Metal,
@@ -129,6 +129,7 @@ impl TryFrom<Build> for super::Build {
129
129
 
130
130
  let backends = if build.general.universal {
131
131
  vec![
132
+ Backend::Cann,
132
133
  Backend::Cpu,
133
134
  Backend::Cuda,
134
135
  Backend::Metal,
@@ -149,6 +149,7 @@ pub enum Kernel {
149
149
  #[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize)]
150
150
  #[serde(deny_unknown_fields, rename_all = "kebab-case")]
151
151
  pub enum Backend {
152
+ Cann,
152
153
  Cpu,
153
154
  Cuda,
154
155
  Metal,
@@ -256,6 +257,7 @@ impl From<TvmFfi> for super::TvmFfi {
256
257
  impl From<Backend> for super::Backend {
257
258
  fn from(backend: Backend) -> Self {
258
259
  match backend {
260
+ Backend::Cann => super::Backend::Cann,
259
261
  Backend::Cpu => super::Backend::Cpu,
260
262
  Backend::Cuda => super::Backend::Cuda,
261
263
  Backend::Metal => super::Backend::Metal,
@@ -437,6 +439,7 @@ impl From<super::TvmFfi> for TvmFfi {
437
439
  impl From<super::Backend> for Backend {
438
440
  fn from(backend: super::Backend) -> Self {
439
441
  match backend {
442
+ super::Backend::Cann => Backend::Cann,
440
443
  super::Backend::Cpu => Backend::Cpu,
441
444
  super::Backend::Cuda => Backend::Cuda,
442
445
  super::Backend::Metal => Backend::Metal,
@@ -10,11 +10,15 @@ use crate::config::Backend;
10
10
  pub struct BackendInfo {
11
11
  #[serde(rename = "type")]
12
12
  pub backend_type: Backend,
13
+ #[serde(skip_serializing_if = "Option::is_none")]
14
+ pub archs: Option<Vec<String>>,
13
15
  }
14
16
 
17
+ /// Kernel metadata.
15
18
  #[derive(Debug, Deserialize, Serialize)]
16
- #[serde(deny_unknown_fields, rename_all = "kebab-case")]
19
+ #[serde(rename_all = "kebab-case")]
17
20
  pub struct Metadata {
21
+ pub id: Option<String>,
18
22
  #[serde(skip_serializing_if = "Option::is_none")]
19
23
  pub version: Option<usize>,
20
24
  #[serde(skip_serializing_if = "Option::is_none")]
@@ -18,6 +18,7 @@ build-backend = "maturin"
18
18
  [dependency-groups]
19
19
  dev = [
20
20
  "pytest>=8",
21
+ "mypy>=1.11",
21
22
  ]
22
23
 
23
24
  [tool.maturin]