tirx-kernels 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,54 @@
1
+ # Licensed to the Apache Software Foundation (ASF) under one
2
+ # or more contributor license agreements. See the NOTICE file
3
+ # distributed with this work for additional information
4
+ # regarding copyright ownership. The ASF licenses this file
5
+ # to you under the Apache License, Version 2.0 (the
6
+ # "License"); you may not use this file except in compliance
7
+ # with the License. You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing,
12
+ # software distributed under the License is distributed on an
13
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
+ # KIND, either express or implied. See the License for the
15
+ # specific language governing permissions and limitations
16
+ # under the License.
17
+ """TIRX kernel library.
18
+
19
+ Private or experimental kernels can be layered on top of the released package by
20
+ setting ``TIRX_KERNELS_OVERLAY_PATHS`` to one or more ``tirx_kernels`` package
21
+ directories, separated by ``os.pathsep``.
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ import os
27
+ from pathlib import Path
28
+
29
# Name of the environment variable that lists extra package directories,
# separated by os.pathsep, to layer on top of the released package.
+ _OVERLAY_ENV = "TIRX_KERNELS_OVERLAY_PATHS"
30
+
31
+
32
def _iter_overlay_paths() -> list[Path]:
    """Return the overlay directories named by the overlay environment variable.

    The variable's value is split on ``os.pathsep``. Each entry is stripped of
    surrounding whitespace; empty entries are ignored. The remaining entries
    have ``~`` expanded, and only those that are existing directories are
    kept. Kept entries are returned resolved, in the order they were listed.
    Duplicates are not removed here.
    """
    raw_value = os.environ.get(_OVERLAY_ENV, "")
    stripped = (chunk.strip() for chunk in raw_value.split(os.pathsep))
    candidates = (Path(entry).expanduser() for entry in stripped if entry)
    return [candidate.resolve() for candidate in candidates if candidate.is_dir()]
43
+
44
+
45
def _append_overlay_paths() -> None:
    """Append overlay directories to this package's ``__path__``.

    Entries already present in ``__path__`` (compared by resolved path string)
    are skipped, as are repeated overlay directories, so each directory
    appears at most once. New entries go at the end, after the existing ones.
    """
    known = set()
    for existing in __path__:
        known.add(str(Path(existing).resolve()))

    for overlay in _iter_overlay_paths():
        overlay_text = str(overlay)
        if overlay_text in known:
            continue
        __path__.append(overlay_text)
        known.add(overlay_text)
52
+
53
+
54
# Runs at import time so overlay directories are added to this package's
# __path__ as soon as the package is imported.
+ _append_overlay_paths()
@@ -0,0 +1,78 @@
1
+ # Licensed to the Apache Software Foundation (ASF) under one
2
+ # or more contributor license agreements. See the NOTICE file
3
+ # distributed with this work for additional information
4
+ # regarding copyright ownership. The ASF licenses this file
5
+ # to you under the Apache License, Version 2.0 (the
6
+ # "License"); you may not use this file except in compliance
7
+ # with the License. You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing,
12
+ # software distributed under the License is distributed on an
13
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
+ # KIND, either express or implied. See the License for the
15
+ # specific language governing permissions and limitations
16
+ # under the License.
17
+ """Standard kernel interface protocol.
18
+
19
+ Every kernel module under ``kernels/<category>/`` that wants to be
20
+ discoverable by the registry must expose:
21
+
22
+ Module-level constants
23
+ ----------------------
24
+ KERNEL_META : dict
25
+ Required keys:
26
+ - "name" (str): unique kernel name used by CLI (e.g. "rmsnorm")
27
+ - "category" (str): one of gemm, gemm_comm, attention, normalization, activation, ssm, loss, moe
28
+ - "compute_capability" (int): minimum SM version (e.g. 10 for sm100a)
29
+
30
+ CONFIGS : list[dict]
31
+ Each dict has a "label" key (str) plus arbitrary kernel-specific
32
+ parameters. This list is the union of the TIR test parametrize
32
+ configs and the bench-ci benchmark configs.
34
+
35
+ BENCH_CONFIGS : list[dict] (optional)
36
+ Benchmark-only configs. If present, ``python -m tirx_kernels.bench``
37
+ uses these instead of ``CONFIGS`` so expensive benchmark sweeps do not
38
+ automatically become pytest correctness cases.
39
+
40
+ Functions
41
+ ---------
42
+ get_kernel(**cfg) -> tvm.tirx.PrimFunc | list[tvm.tirx.PrimFunc]
43
+ Return the TIR PrimFunc(s) for this kernel. Multi-kernel workloads
44
+ (e.g. split-k GEMM with a separate reduce kernel) return a list.
45
+
46
+ prepare_data(**cfg) -> dict[str, Any]
47
+ Prepare input/output tensors. Returns a dict mapping argument names
48
+ to tensors (torch.Tensor or numpy.ndarray).
49
+
50
+ check_correctness(outputs: dict, **cfg) -> None
51
+ Validate kernel outputs against a reference.
52
+ Raise AssertionError on mismatch.
53
+
54
+ get_baselines(**cfg) -> dict[str, Callable] (optional)
55
+ Return {name: callable} for baseline implementations used in
56
+ benchmarking (e.g. cublas, flashinfer).
57
+ """
58
+
59
+ from __future__ import annotations
60
+
61
+ from typing import Any, Protocol, runtime_checkable
62
+
63
+
64
@runtime_checkable
class KernelModule(Protocol):
    """Structural interface a kernel module has to provide.

    Because the protocol is runtime-checkable, ``isinstance(obj, KernelModule)``
    verifies that all members below are present on ``obj``; it does not
    validate their types or signatures.
    """

    # Kernel metadata dictionary (string keys, arbitrary values).
    KERNEL_META: dict[str, Any]
    # List of configuration dictionaries (string keys, arbitrary values).
    CONFIGS: list[dict[str, Any]]

    @staticmethod
    def get_kernel(**kwargs: Any) -> Any:
        """Return the kernel object(s) for the given config keyword arguments."""

    @staticmethod
    def prepare_data(**kwargs: Any) -> dict[str, Any]:
        """Return a dict of prepared data for the given config keyword arguments."""

    @staticmethod
    def check_correctness(outputs: dict[str, Any], **kwargs: Any) -> None:
        """Validate ``outputs`` for the given config keyword arguments."""
File without changes