ninetoothed 0.5.0__tar.gz → 0.7.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ninetoothed-0.5.0 → ninetoothed-0.7.0}/PKG-INFO +9 -2
- {ninetoothed-0.5.0 → ninetoothed-0.7.0}/README.md +8 -0
- {ninetoothed-0.5.0 → ninetoothed-0.7.0}/docs/README.zh.md +8 -0
- ninetoothed-0.7.0/docs/source/_static/matmul-tiling.png +0 -0
- ninetoothed-0.7.0/docs/source/_static/vecadd-tiling.png +0 -0
- {ninetoothed-0.5.0 → ninetoothed-0.7.0}/pyproject.toml +1 -1
- {ninetoothed-0.5.0 → ninetoothed-0.7.0}/src/ninetoothed/jit.py +152 -27
- ninetoothed-0.7.0/src/ninetoothed/naming.py +50 -0
- {ninetoothed-0.5.0 → ninetoothed-0.7.0}/src/ninetoothed/symbol.py +39 -44
- {ninetoothed-0.5.0 → ninetoothed-0.7.0}/src/ninetoothed/tensor.py +40 -14
- ninetoothed-0.7.0/tests/test_addmm.py +104 -0
- ninetoothed-0.7.0/tests/test_naming.py +51 -0
- {ninetoothed-0.5.0 → ninetoothed-0.7.0}/tests/test_softmax.py +1 -2
- {ninetoothed-0.5.0 → ninetoothed-0.7.0}/.github/workflows/pytest.yml +0 -0
- {ninetoothed-0.5.0 → ninetoothed-0.7.0}/.github/workflows/ruff.yml +0 -0
- {ninetoothed-0.5.0 → ninetoothed-0.7.0}/.gitignore +0 -0
- {ninetoothed-0.5.0 → ninetoothed-0.7.0}/LICENSE +0 -0
- {ninetoothed-0.5.0 → ninetoothed-0.7.0}/docs/source/_static/ninetoothed-logo.png +0 -0
- {ninetoothed-0.5.0 → ninetoothed-0.7.0}/requirements.txt +0 -0
- {ninetoothed-0.5.0 → ninetoothed-0.7.0}/src/ninetoothed/__init__.py +0 -0
- {ninetoothed-0.5.0 → ninetoothed-0.7.0}/src/ninetoothed/language.py +0 -0
- {ninetoothed-0.5.0 → ninetoothed-0.7.0}/src/ninetoothed/torchifier.py +0 -0
- {ninetoothed-0.5.0 → ninetoothed-0.7.0}/tests/__init__.py +0 -0
- {ninetoothed-0.5.0 → ninetoothed-0.7.0}/tests/skippers.py +0 -0
- {ninetoothed-0.5.0 → ninetoothed-0.7.0}/tests/test_add.py +0 -0
- {ninetoothed-0.5.0 → ninetoothed-0.7.0}/tests/test_matmul.py +0 -0
{ninetoothed-0.5.0 → ninetoothed-0.7.0}/PKG-INFO

@@ -1,11 +1,10 @@
 Metadata-Version: 2.3
 Name: ninetoothed
-Version: 0.5.0
+Version: 0.7.0
 Summary: A domain-specific language based on Triton but providing higher-level abstraction.
 Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
 Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
 Author-email: Jiacheng Huang <huangjiacheng0709@outlook.com>
-License-File: LICENSE
 Classifier: License :: OSI Approved :: Apache Software License
 Classifier: Operating System :: OS Independent
 Classifier: Programming Language :: Python :: 3
@@ -51,6 +50,8 @@ def add_kernel(
 
 In this code, we first define `BLOCK_SIZE`, which is a `Symbol`; you can think of `"BLOCK_SIZE"` as its name. We see that `meta` is set to `True`, indicating to the compiler that it is a meta-parameter whose value can be determined by the compiler. `Tensor(1)` constructs a one-dimensional tensor (a vector), and `Tensor(1).tile((BLOCK_SIZE,))` means we want to create a vector and divide it into blocks of size `BLOCK_SIZE`. Suppose the size of this vector is `8192` and `BLOCK_SIZE` is `1024`; then the vector will be divided into `8` blocks, each of size `1024`.
 
+![
+
 By using type annotations, we tell the compiler that we will have three tensor parameters, which will be divided into blocks, and that `x`, `y`, and `z` are these blocks. It is important to understand that `x`, `y`, and `z` are the blocks, not the tensors themselves; in the function body, `x`, `y`, and `z` are also the blocks. The rest is straightforward (only the line `z = x + y` is left): we add each block of `x` and `y` and store the result in `z`. Since every block of the parameter tensors undergoes this operation, the addition is completed for the whole tensors as well.
 
 ### Matrix Multiplication
@@ -82,4 +83,10 @@ def matmul_kernel(a: a_tiled, b: b_tiled, c: c_tiled):
 
 For matrix multiplication, we also have three tensor parameters, but the tiling method is more complex than for vector addition. We denote the three matrices as $A$, $B$, and $C$, where $A$ and $B$ are inputs and $C$ is the output. Tiling $C$ is simple: we just need to divide it into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_N)` by rows and columns; once each block computes its result, the entire $C$ is computed. However, how should we tile $A$ and $B$? The answer is to introduce another meta-parameter, `BLOCK_SIZE_K`. This way, we can divide $A$ into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_K)` and $B$ into blocks of size `(BLOCK_SIZE_K, BLOCK_SIZE_N)`. However, for matrix multiplication, $A$ and $B$ do not correspond block by block; each row of $A$ needs to correspond to each column of $B$. Therefore, we need to further `tile` $A$ and $B$ by rows and columns, respectively. Up to this point, we have a set of row blocks of $A$ and column blocks of $B$, but each row block of $A$ must correspond to every column block of $B$. This is where `expand` comes in: we `expand` the row blocks of $A$ along the columns to the number of columns of $C$, and the column blocks of $B$ along the rows to the number of rows of $C$. This way, we have successfully tiled $A$, $B$, and $C$. In fact, the meta-operations up to this point are already enough to write the kernel function. However, we notice that the levels where the row blocks and column blocks reside are two-dimensional, and their sizes are of the forms `(1, ...)` and `(..., 1)`. This means that if no other operations were performed, we would have to access row blocks and column blocks as `a[0, k]` and `b[k, 0]`, and to find the range of `k` from `a` we would need `a.shape[1]`. But dimensions of size `1` can be removed completely, which is why we added the two lines of `squeeze`. Here, `dtype` refers to the data type, which in PyTorch is generally an integer or floating-point type such as `torch.float32`; however, since meta-operations like `tile` can be performed in NineToothed, `dtype` can also be a `Tensor`. In other words, NineToothed has the concept of "tensors that store tensors". In summary, these two lines operate on the tensors stored in the outermost tensor, removing the dimensions of size `1`. This way, when we access the row and column blocks we can use `a[k]` and `b[k]`, and when finding the range of `k` we can use `a.shape[0]`.
 
+![
+
 With tiling done, the rest is simple. In the function body, we define an `accumulator` to accumulate intermediate results. We then iterate through the corresponding row blocks of $A$ and column blocks of $B$, multiplying them and accumulating the results into `accumulator`. Finally, we place the `accumulator` in the corresponding block of $C$. Since every block of the parameter tensors undergoes this operation, the multiplication is completed for the whole tensors as well.
+
+## License
+
+This project is distributed under the Apache-2.0 license. See the included [LICENSE](LICENSE) file for details.
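The `add_kernel` listing that the paragraphs above describe is collapsed in this diff view. As a rough sketch of the shape of such a kernel, reconstructed from the description (the exact signature and names are assumptions, not the verbatim README listing):

```python
import ninetoothed
from ninetoothed import Symbol, Tensor

# A meta-parameter: the compiler is free to choose its value.
BLOCK_SIZE = Symbol("BLOCK_SIZE", meta=True)


@ninetoothed.jit
def add_kernel(
    x: Tensor(1).tile((BLOCK_SIZE,)),  # each parameter is a 1-D tensor
    y: Tensor(1).tile((BLOCK_SIZE,)),  # tiled into blocks of size BLOCK_SIZE
    z: Tensor(1).tile((BLOCK_SIZE,)),
):
    # x, y, and z are the blocks, not the whole tensors.
    z = x + y
```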
{ninetoothed-0.5.0 → ninetoothed-0.7.0}/README.md

@@ -36,6 +36,8 @@ def add_kernel(
 
 In this code, we first define `BLOCK_SIZE`, which is a `Symbol`; you can think of `"BLOCK_SIZE"` as its name. We see that `meta` is set to `True`, indicating to the compiler that it is a meta-parameter whose value can be determined by the compiler. `Tensor(1)` constructs a one-dimensional tensor (a vector), and `Tensor(1).tile((BLOCK_SIZE,))` means we want to create a vector and divide it into blocks of size `BLOCK_SIZE`. Suppose the size of this vector is `8192` and `BLOCK_SIZE` is `1024`; then the vector will be divided into `8` blocks, each of size `1024`.
 
+![
+
 By using type annotations, we tell the compiler that we will have three tensor parameters, which will be divided into blocks, and that `x`, `y`, and `z` are these blocks. It is important to understand that `x`, `y`, and `z` are the blocks, not the tensors themselves; in the function body, `x`, `y`, and `z` are also the blocks. The rest is straightforward (only the line `z = x + y` is left): we add each block of `x` and `y` and store the result in `z`. Since every block of the parameter tensors undergoes this operation, the addition is completed for the whole tensors as well.
 
 ### Matrix Multiplication
@@ -67,4 +69,10 @@ def matmul_kernel(a: a_tiled, b: b_tiled, c: c_tiled):
 
 For matrix multiplication, we also have three tensor parameters, but the tiling method is more complex than for vector addition. We denote the three matrices as $A$, $B$, and $C$, where $A$ and $B$ are inputs and $C$ is the output. Tiling $C$ is simple: we just need to divide it into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_N)` by rows and columns; once each block computes its result, the entire $C$ is computed. However, how should we tile $A$ and $B$? The answer is to introduce another meta-parameter, `BLOCK_SIZE_K`. This way, we can divide $A$ into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_K)` and $B$ into blocks of size `(BLOCK_SIZE_K, BLOCK_SIZE_N)`. However, for matrix multiplication, $A$ and $B$ do not correspond block by block; each row of $A$ needs to correspond to each column of $B$. Therefore, we need to further `tile` $A$ and $B$ by rows and columns, respectively. Up to this point, we have a set of row blocks of $A$ and column blocks of $B$, but each row block of $A$ must correspond to every column block of $B$. This is where `expand` comes in: we `expand` the row blocks of $A$ along the columns to the number of columns of $C$, and the column blocks of $B$ along the rows to the number of rows of $C$. This way, we have successfully tiled $A$, $B$, and $C$. In fact, the meta-operations up to this point are already enough to write the kernel function. However, we notice that the levels where the row blocks and column blocks reside are two-dimensional, and their sizes are of the forms `(1, ...)` and `(..., 1)`. This means that if no other operations were performed, we would have to access row blocks and column blocks as `a[0, k]` and `b[k, 0]`, and to find the range of `k` from `a` we would need `a.shape[1]`. But dimensions of size `1` can be removed completely, which is why we added the two lines of `squeeze`. Here, `dtype` refers to the data type, which in PyTorch is generally an integer or floating-point type such as `torch.float32`; however, since meta-operations like `tile` can be performed in NineToothed, `dtype` can also be a `Tensor`. In other words, NineToothed has the concept of "tensors that store tensors". In summary, these two lines operate on the tensors stored in the outermost tensor, removing the dimensions of size `1`. This way, when we access the row and column blocks we can use `a[k]` and `b[k]`, and when finding the range of `k` we can use `a.shape[0]`.
 
+![
+
 With tiling done, the rest is simple. In the function body, we define an `accumulator` to accumulate intermediate results. We then iterate through the corresponding row blocks of $A$ and column blocks of $B$, multiplying them and accumulating the results into `accumulator`. Finally, we place the `accumulator` in the corresponding block of $C$. Since every block of the parameter tensors undergoes this operation, the multiplication is completed for the whole tensors as well.
+
+## License
+
+This project is distributed under the Apache-2.0 license. See the included [LICENSE](LICENSE) file for details.
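The tiling described in the matrix-multiplication walkthrough can be written roughly as follows. This is a sketch assembled from the description and from the very similar arrangement used in tests/test_addmm.py later in this diff, not the verbatim README listing:

```python
import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Symbol, Tensor

BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True)
BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True)
BLOCK_SIZE_K = Symbol("BLOCK_SIZE_K", meta=True)

# C is divided into (BLOCK_SIZE_M, BLOCK_SIZE_N) output blocks.
c_tiled = Tensor(2).tile((BLOCK_SIZE_M, BLOCK_SIZE_N))

# A is divided into (BLOCK_SIZE_M, BLOCK_SIZE_K) blocks, grouped into row
# blocks, and expanded along the columns to match the column blocks of C.
a_tiled = (
    Tensor(2)
    .tile((BLOCK_SIZE_M, BLOCK_SIZE_K))
    .tile((1, -1))
    .expand((-1, c_tiled.shape[1]))
)
# Remove the size-1 dimension so a row block can be indexed as a[k].
a_tiled.dtype = a_tiled.dtype.squeeze(0)

# B is divided into (BLOCK_SIZE_K, BLOCK_SIZE_N) blocks, grouped into column
# blocks, and expanded along the rows to match the row blocks of C.
b_tiled = (
    Tensor(2)
    .tile((BLOCK_SIZE_K, BLOCK_SIZE_N))
    .tile((-1, 1))
    .expand((c_tiled.shape[0], -1))
)
b_tiled.dtype = b_tiled.dtype.squeeze(1)


@ninetoothed.jit
def matmul_kernel(a: a_tiled, b: b_tiled, c: c_tiled):
    accumulator = ntl.zeros(c.shape, dtype=ntl.float32)
    for k in range(a.shape[0]):
        accumulator += ntl.dot(a[k], b[k])
    c = accumulator.to(ntl.float16)
```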
{ninetoothed-0.5.0 → ninetoothed-0.7.0}/docs/README.zh.md (Chinese README; content translated below)

@@ -36,6 +36,8 @@ def add_kernel(
 
 In this code, we first define `BLOCK_SIZE`, which is a `Symbol`; you can think of `"BLOCK_SIZE"` as its name. We see that `meta` is set to `True`, which tells the compiler that it is a meta-parameter whose value the compiler may choose. `Tensor(1)` constructs a one-dimensional tensor (a vector), and `Tensor(1).tile((BLOCK_SIZE,))` means we want to construct a vector and divide it into blocks of size `BLOCK_SIZE`. Suppose the size of this vector is `8192` and `BLOCK_SIZE` is `1024`; then the vector will be divided into `8` blocks, each of size `1024`.
 
+![
+
 By using type annotations, we tell the compiler that we will have three tensor parameters, each of which is tiled in this way, and that `x`, `y`, and `z` are the resulting blocks. It is important to realize that `x`, `y`, and `z` are the blocks, not the tensors being tiled, and that in the function body `x`, `y`, and `z` are also the blocks. The rest is easy to follow (only the line `z = x + y` is left): we add each block of `x` and `y` and store the result in `z`. Since every block of the parameter tensors undergoes this operation, the addition is completed for the tensors as a whole as well.
 
 ### Matrix Multiplication
@@ -67,4 +69,10 @@ def matmul_kernel(a: a_tiled, b: b_tiled, c: c_tiled):
 
 For matrix multiplication, we also have three tensor parameters, but the tiling is certainly more complex than for vector addition. We denote the three matrices as $A$, $B$, and $C$, where $A$ and $B$ are inputs and $C$ is the output. Tiling $C$ is simple: we just divide it by rows and columns into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_N)`; once each such block has computed its result, the whole of $C$ has been computed. How, then, should we tile $A$ and $B$? The answer is to introduce another meta-parameter, `BLOCK_SIZE_K`, so that $A$ can be divided into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_K)` and $B$ into blocks of size `(BLOCK_SIZE_K, BLOCK_SIZE_N)`. However, for matrix multiplication, $A$ and $B$ do not correspond block by block; each row of $A$ must correspond to each column of $B$, so we need to `tile` further, dividing $A$ into row blocks and $B$ into column blocks. At this point we have a set of row blocks of $A$ and column blocks of $B$, but every row block of $A$ must correspond to every column block of $B$. This is where `expand` comes in: we `expand` the row blocks of $A$ along the columns to as many columns as $C$ has, and the column blocks of $B$ along the rows to as many rows as $C$ has. With that, $A$, $B$, and $C$ are all tiled, and for each block of $C$ we have the matching row block of $A$ and column block of $B$. The meta-operations so far are already enough to write the kernel function, but we notice that the levels where the row blocks and column blocks reside are two-dimensional, with sizes of the forms `(1, ...)` and `(..., 1)`. That means that without further operations we would have to access the row and column blocks as `a[0, k]` and `b[k, 0]`, and to find the range of `k` from `a` we would need `a.shape[1]`. But dimensions of size `1` can be removed entirely, which is why we added the two lines of `squeeze`. Here, `dtype` means the data type, which in PyTorch is generally an integer or floating-point type such as `torch.float32`; however, since meta-operations like `tile` can be performed in NineToothed, `dtype` can also be a `Tensor`. In other words, NineToothed has the concept of "tensors that store tensors". In short, these two lines operate on the next-level tensors stored by the outermost tensor, removing the dimensions of size `1`, so that we can access the row and column blocks as `a[k]` and `b[k]` and find the range of `k` with `a.shape[0]`.
 
+![
+
 With the tiling matched up, what follows is much simpler. In the function body, we define an `accumulator` to accumulate intermediate results, then iterate over the matched row blocks of $A$ and column blocks of $B$, accumulating their products into `accumulator`, and finally place `accumulator` into the corresponding block of $C$. Since every block of the parameter tensors undergoes this operation, the multiplication is completed for the tensors as a whole as well.
+
+## License
+
+This project is distributed under the Apache-2.0 license. See the included [LICENSE](LICENSE) file for details.
ninetoothed-0.7.0/docs/source/_static/matmul-tiling.png: Binary file
ninetoothed-0.7.0/docs/source/_static/vecadd-tiling.png: Binary file
{ninetoothed-0.5.0 → ninetoothed-0.7.0}/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "ninetoothed"
-version = "0.5.0"
+version = "0.7.0"
 authors = [{ name = "Jiacheng Huang", email = "huangjiacheng0709@outlook.com" }]
 description = "A domain-specific language based on Triton but providing higher-level abstraction."
 readme = "README.md"
{ninetoothed-0.5.0 → ninetoothed-0.7.0}/src/ninetoothed/jit.py

@@ -1,29 +1,41 @@
 import ast
 import collections
 import functools
+import importlib.util
 import inspect
 import itertools
 import math
+import subprocess
+import sys
 import tempfile
 
 import triton
 
+import ninetoothed.naming as naming
 from ninetoothed.language import attribute, call
 from ninetoothed.symbol import Symbol
 from ninetoothed.tensor import Tensor
 from ninetoothed.torchifier import Torchifier
 
 
-def jit(
-
+def jit(_func=None, *, _prettify=False):
+    def wrapper(func):
+        return JIT(func, _prettify=_prettify)()
+
+    if _func is None:
+        return wrapper
+
+    return wrapper(_func)
 
 
 class JIT:
     handles = collections.defaultdict(dict)
 
-    def __init__(self, func):
+    def __init__(self, func, _prettify=False):
         self.func = func
 
+        self._prettify = _prettify
+
     def __call__(self):
         source_file = inspect.getsourcefile(self.func)
         source_line = inspect.getsourcelines(self.func)[1]
@@ -38,27 +50,37 @@ class JIT:
 
         CodeGenerator(inspect.get_annotations(self.func)).visit(tree)
         Tritonizer().visit(tree)
+        _BinOpSimplifier().visit(tree)
         ast.fix_missing_locations(tree)
 
+        if self._prettify:
+            name_collector = _SimplifiedNameCollector()
+            name_collector.visit(tree)
+
         unparsed = ast.unparse(tree).replace("None:", ":").replace(":None", ":")
+        dependencies = self._find_dependencies()
+        source = "\n\n".join((unparsed, dependencies)).strip()
 
-
-
-
+        if self._prettify:
+            for original, simplified in name_collector.simplified_names.items():
+                if simplified not in name_collector.simplified_names:
+                    source = source.replace(original, simplified)
 
-
-
-                source=temp_file.read(),
-                filename=temp_file_name,
-                mode="exec",
+            source = subprocess.check_output(
+                ["ruff", "format", "-"], input=source, encoding="utf-8"
             )
 
-
-
+        with tempfile.NamedTemporaryFile(delete=False, suffix=".py") as temp_file:
+            temp_file.write(source.encode("utf-8"))
+            temp_file_name = temp_file.name
+
+        module = type(self)._import_from_path(temp_file_name, temp_file_name)
+        module_vars = vars(module)
 
         handle = _Handle(
-
-
+            module_vars[self.func.__name__],
+            module_vars[f"launch_{self.func.__name__}"],
+            source,
         )
 
         type(self).handles[source_file][source_line] = handle
@@ -69,10 +91,30 @@
         module = ast.parse(inspect.getsource(inspect.getmodule(self.func)))
 
         _AliasRestorer().visit(module)
+        collector = _ImportCollector()
+        collector.visit(module)
         finder = _FunctionDefFinder(self.func.__name__)
         finder.visit(module)
 
-        return ast.Module(body=[finder.result], type_ignores=[])
+        return ast.Module(body=collector.imports + [finder.result], type_ignores=[])
+
+    def _find_dependencies(self):
+        dependencies = set()
+
+        for obj in self.func.__globals__.values():
+            if isinstance(obj, triton.runtime.JITFunction):
+                dependencies.add(obj.src)
+
+        return "\n".join(f"@triton.jit\n{dependency}" for dependency in dependencies)
+
+    @staticmethod
+    def _import_from_path(module_name, file_path):
+        spec = importlib.util.spec_from_file_location(module_name, file_path)
+        module = importlib.util.module_from_spec(spec)
+        sys.modules[module_name] = module
+        spec.loader.exec_module(module)
+
+        return module
 
 
 class CodeGenerator(ast.NodeTransformer):
@@ -102,7 +144,7 @@ class CodeGenerator(ast.NodeTransformer):
         self.generic_visit(node)
 
         for arg in self._args:
-            if not isinstance(arg, Tensor):
+            if not isinstance(arg, Tensor) or arg.ndim == 0:
                 continue
 
             offsets = arg.offsets()
@@ -131,12 +173,17 @@
 
         names_of_args = [arg.names() - {"ninetoothed"} for arg in self._args]
         names = functools.reduce(lambda x, y: x | y, names_of_args)
-        meta_names = {name for name in names if
+        meta_names = {name for name in names if naming.is_meta(name)}
         non_meta_names = {name for name in names if name not in meta_names}
+        non_meta_names |= {
+            naming.make_next_power_of_2(name)
+            for name in non_meta_names
+            if naming.is_constexpr(name)
+        }
 
         node.args = [
             ast.arg(arg=name)
-            if not
+            if not naming.is_constexpr(name)
            else ast.arg(arg=name, annotation=attribute("constexpr").node)
             for name in non_meta_names
         ] + [
@@ -287,9 +334,20 @@
         )
 
     def _generate_launch(self, params, meta):
-
-
-
+        non_next_power_of_2_constexpr_params = [
+            param
+            for param in params
+            if naming.is_constexpr(param) and not naming.is_next_power_of_2(param)
+        ]
+        non_next_power_of_2_constexpr_params_without_prefixes = [
+            naming.remove_prefixes(param)
+            for param in non_next_power_of_2_constexpr_params
+        ]
+        next_power_of_2_params = [
+            param for param in params if naming.is_next_power_of_2(param)
+        ]
+        next_power_of_2_params_without_prefixes = [
+            naming.remove_prefixes(param) for param in next_power_of_2_params
         ]
 
         launch = ast.FunctionDef(
@@ -297,17 +355,33 @@
             args=ast.arguments(
                 posonlyargs=[],
                 args=[ast.arg(arg=arg.original.name) for arg in self._args]
-                + [
+                + [
+                    ast.arg(arg=param)
+                    for param in non_next_power_of_2_constexpr_params_without_prefixes
+                ],
                 kwonlyargs=[],
                 defaults=[],
             ),
             body=[
                 ast.Assign(
                     targets=[ast.Name(id=param, ctx=ast.Store())],
-                    value=ast.Name(id=
+                    value=ast.Name(id=param_without_prefixes, ctx=ast.Load()),
+                )
+                for param, param_without_prefixes in zip(
+                    non_next_power_of_2_constexpr_params,
+                    non_next_power_of_2_constexpr_params_without_prefixes,
+                )
+            ]
+            + [
+                ast.Assign(
+                    targets=[ast.Name(id=param, ctx=ast.Store())],
+                    value=Symbol(
+                        f"triton.next_power_of_2({param_without_prefixes})"
+                    ).node,
                 )
-                for param,
-
+                for param, param_without_prefixes in zip(
+                    next_power_of_2_params,
+                    next_power_of_2_params_without_prefixes,
                 )
             ]
             + [
@@ -355,6 +429,9 @@
 
     @staticmethod
     def _generate_load(tensor, intermediate_indices=()):
+        if tensor.ndim == 0:
+            return Symbol(tensor.original.name).node
+
         pointers, mask = CodeGenerator._generate_pointers_and_mask(
             tensor, intermediate_indices
         )
@@ -458,10 +535,41 @@ class Tritonizer(ast.NodeTransformer):
         return node
 
 
+class _BinOpSimplifier(ast.NodeTransformer):
+    def visit_BinOp(self, node):
+        self.generic_visit(node)
+
+        if isinstance(node.op, ast.Mult):
+            left = Symbol(node.left)
+            right = Symbol(node.right)
+
+            if left == 0 or right == 0:
+                return Symbol(0).node
+
+            if left == 1:
+                return node.right
+
+            if right == 1:
+                return node.left
+
+        return node
+
+
+class _SimplifiedNameCollector(ast.NodeVisitor):
+    def __init__(self):
+        self.simplified_names = {}
+
+    def visit_Name(self, node):
+        self.generic_visit(node)
+
+        self.simplified_names[node.id] = naming.remove_prefixes(node.id)
+
+
 class _Handle:
-    def __init__(self, kernel, launch):
+    def __init__(self, kernel, launch, source):
         self._kernel = kernel
         self._launch = launch
+        self._source = source
 
     def __call__(self, *args, **kwargs):
         return self._launch(*args, **kwargs)
@@ -515,6 +623,23 @@ class _AliasRestorer(ast.NodeTransformer):
         return node
 
 
+class _ImportCollector(ast.NodeVisitor):
+    def __init__(self):
+        super().__init__()
+
+        self.imports = []
+
+    def visit_Import(self, node):
+        self.imports.append(node)
+
+        self.generic_visit(node)
+
+    def visit_ImportFrom(self, node):
+        self.imports.append(node)
+
+        self.generic_visit(node)
+
+
 class _FunctionDefFinder(ast.NodeVisitor):
     def __init__(self, name):
         self._name = name
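The reworked `jit` entry point above dispatches on whether `_func` is given, so the decorator can be used either bare or with keyword arguments. A minimal usage sketch (the copy kernel itself is illustrative, and compilation happens at decoration time, so a working Triton setup is assumed):

```python
import ninetoothed
from ninetoothed import Symbol, Tensor

BLOCK_SIZE = Symbol("BLOCK_SIZE", meta=True)
blocks = Tensor(1).tile((BLOCK_SIZE,))


# Bare form: equivalent to jit(_func=copy_kernel).
@ninetoothed.jit
def copy_kernel(src: blocks, dst: blocks):
    dst = src


# Parameterized form: jit(_prettify=True) first returns a wrapper, which is
# then applied to the function; the generated source is cleaned up so that
# the prefixed internal parameter names read naturally.
@ninetoothed.jit(_prettify=True)
def copy_kernel_pretty(src: blocks, dst: blocks):
    dst = src
```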
ninetoothed-0.7.0/src/ninetoothed/naming.py

@@ -0,0 +1,50 @@
+import re
+
+
+def make_constexpr(name):
+    return _add_prefix(name, _CONSTEXPR)
+
+
+def make_meta(name):
+    return _add_prefix(name, _META)
+
+
+def make_next_power_of_2(name):
+    return _add_prefix(name, _NEXT_POWER_OF_2)
+
+
+def is_constexpr(name):
+    return _CONSTEXPR in _find_prefixes(name) or is_meta(name)
+
+
+def is_meta(name):
+    return _META in _find_prefixes(name)
+
+
+def is_next_power_of_2(name):
+    return _NEXT_POWER_OF_2 in _find_prefixes(name)
+
+
+def remove_prefixes(name):
+    return _PREFIX_PATTERN.sub("", name)
+
+
+_CONSTEXPR = "constexpr"
+
+_META = "meta"
+
+_NEXT_POWER_OF_2 = "next_power_of_2"
+
+_PREFIX_PATTERN = re.compile(r"ninetoothed_((?!_).*?)_prefix_")
+
+
+def _add_prefix(name, string):
+    return f"{_make_prefix(string)}{name}"
+
+
+def _make_prefix(string):
+    return f"ninetoothed_{string}_prefix_"
+
+
+def _find_prefixes(name):
+    return set(_PREFIX_PATTERN.findall(name))
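Since the whole module is shown above, its behaviour can be summarized with a short usage sketch: the helpers simply stack and strip string prefixes on parameter names.

```python
import ninetoothed.naming as naming

name = "BLOCK_SIZE"

meta = naming.make_meta(name)
padded = naming.make_next_power_of_2(naming.make_constexpr(name))

assert meta == "ninetoothed_meta_prefix_BLOCK_SIZE"
assert naming.is_meta(meta)
assert naming.is_constexpr(meta)  # every meta name also counts as constexpr
assert naming.is_next_power_of_2(padded)
assert naming.remove_prefixes(padded) == name  # all prefixes are stripped at once
```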
{ninetoothed-0.5.0 → ninetoothed-0.7.0}/src/ninetoothed/symbol.py

@@ -1,7 +1,10 @@
 import ast
 import inspect
+import numbers
 import types
 
+import ninetoothed.naming as naming
+
 
 class Symbol:
     def __init__(self, expr, constexpr=None, meta=None):
@@ -28,18 +31,31 @@ class Symbol:
             if constexpr is False:
                 raise ValueError("Non-constexpr meta symbol is not supported.")
 
-            self._node.id =
+            self._node.id = naming.make_meta(self._node.id)
 
         if constexpr:
-            self._node.id =
+            self._node.id = naming.make_constexpr(self._node.id)
+
+    def __eq__(self, other):
+        if isinstance(self._node, ast.Constant):
+            if isinstance(other, Symbol) and isinstance(other._node, ast.Constant):
+                return self._node.value == other._node.value
+
+            if isinstance(other, numbers.Number):
+                return self._node.value == other
+
+        return False
+
+    def __hash__(self):
+        return id(self)
 
     def __add__(self, other):
         other = type(self)(other)
 
-        if
+        if self == 0:
             return other
 
-        if
+        if other == 0:
             return self
 
         return type(self)(ast.BinOp(left=self._node, op=ast.Add(), right=other._node))
@@ -47,19 +63,30 @@ class Symbol:
     def __radd__(self, other):
         return self.__add__(other)
 
-    def
+    def __sub__(self, other):
         other = type(self)(other)
 
-        if
-            return
+        if self == 0:
+            return -other
 
-        if
+        if other == 0:
+            return self
+
+        return type(self)(ast.BinOp(left=self._node, op=ast.Sub(), right=other._node))
+
+    def __rsub__(self, other):
+        return type(self)(other).__sub__(self)
+
+    def __mul__(self, other):
+        other = type(self)(other)
+
+        if self == 0 or other == 0:
             return type(self)(0)
 
-        if
+        if self == 1:
             return other
 
-        if
+        if other == 1:
             return self
 
         return type(self)(ast.BinOp(left=self._node, op=ast.Mult(), right=other._node))
@@ -136,40 +163,8 @@ class Symbol:
         return SliceSimplifier().visit(self._node)
 
     @staticmethod
-    def
-        return
-
-    @staticmethod
-    def is_meta(name):
-        return name.startswith(Symbol._meta_prefix())
-
-    @staticmethod
-    def remove_prefix(name):
-        if name.startswith(Symbol._constexpr_prefix()):
-            return name.removeprefix(Symbol._constexpr_prefix())
-
-        if name.startswith(Symbol._meta_prefix()):
-            return name.removeprefix(Symbol._meta_prefix())
-
-    @staticmethod
-    def _create_constexpr(name):
-        return f"{Symbol._constexpr_prefix()}{name}"
-
-    @staticmethod
-    def _create_meta(name):
-        return f"{Symbol._meta_prefix()}{name}"
-
-    @staticmethod
-    def _constexpr_prefix():
-        return f"{Symbol._ninetoothed_prefix()}constexpr_"
-
-    @staticmethod
-    def _meta_prefix():
-        return f"{Symbol._ninetoothed_prefix()}meta_"
-
-    @staticmethod
-    def _ninetoothed_prefix():
-        return "_ninetoothed_"
+    def is_name(object):
+        return isinstance(object, Symbol) and isinstance(object.node, ast.Name)
 
 
 class _FindAndReplacer(ast.NodeTransformer):
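A quick sketch of what the new `__eq__`, `__hash__`, and arithmetic methods above buy: constant symbols compare by value, and identity elements are folded away instead of growing the AST. (Constructing `Symbol` directly like this is for illustration only.)

```python
import ast

from ninetoothed.symbol import Symbol

x = Symbol("x")  # wraps an ast.Name node

# Constant symbols compare equal to plain numbers and to other constant symbols.
assert Symbol(2) == 2
assert Symbol(2) == Symbol(2)
assert not (x == 2)  # a name is never equal to a number

# Identity elements are folded: no new ast.BinOp node is created.
assert (x + 0) is x
assert (x * 1) is x
assert isinstance((x * 0).node, ast.Constant)  # multiplication by zero collapses to 0
assert isinstance((x + Symbol("y")).node, ast.BinOp)  # the general case still builds a BinOp
```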
{ninetoothed-0.5.0 → ninetoothed-0.7.0}/src/ninetoothed/tensor.py

@@ -1,6 +1,7 @@
 import itertools
 import re
 
+import ninetoothed.naming as naming
 from ninetoothed.language import call
 from ninetoothed.symbol import Symbol
 
@@ -15,13 +16,15 @@ class Tensor:
         dtype=None,
         strides=None,
         other=None,
+        name=None,
         original=None,
     ):
-        type(self).num_instances += 1
-
         self.dtype = dtype
 
-
+        if name is not None:
+            self.name = name
+        else:
+            self.name = f"_ninetoothed_tensor_{type(self).num_instances}"
 
         if ndim is not None:
             self.shape = (Symbol(self.size_string(i)) for i in range(ndim))
@@ -41,29 +44,41 @@ class Tensor:
         else:
             self.original = self
 
-
-
-
+        type(self).num_instances += 1
+
+    def tile(self, tile_shape, strides=None, dilation=None):
+        if strides is None:
+            strides = [-1 for _ in tile_shape]
+
+        if dilation is None:
+            dilation = [1 for _ in tile_shape]
 
         outer_shape = []
         outer_strides = []
         inner_shape = []
         inner_strides = []
 
-        for
-            self.shape, self.strides, tile_shape,
+        for self_size, self_stride, tile_size, stride, spacing in zip(
+            self.shape, self.strides, tile_shape, strides, dilation
         ):
             if tile_size == -1:
-                tile_size =
+                tile_size = self_size
+
+            if stride == -1:
+                stride = tile_size
 
-            new_size =
+            new_size = (
+                call("cdiv", self_size - spacing * (tile_size - 1) - 1, stride) + 1
+                if stride != 0
+                else -1
+            )
             outer_shape.append(new_size)
 
-            new_stride =
+            new_stride = self_stride * stride // spacing
             outer_strides.append(new_stride)
 
             inner_shape.append(tile_size)
-            next_stride =
+            next_stride = self_stride * spacing
             inner_strides.append(next_stride)
 
         return type(self)(
@@ -103,6 +118,9 @@ class Tensor:
         )
 
     def names(self):
+        if self.ndim == 0:
+            return {self.original.name}
+
         return (
             {self.original.pointer_string()}
             | {
@@ -112,6 +130,7 @@ class Tensor:
                 for name in value.names()
             }
             | (self.dtype.names() if isinstance(self.dtype, type(self)) else set())
+            | (self.original.names() if self.original is not self else set())
         )
 
     def offsets(self, indices=None):
@@ -160,7 +179,14 @@ class Tensor:
 
             if isinstance(curr, type(self)):
                 for dim in range(curr.ndim):
-
+                    size = curr.shape[dim]
+
+                    if Symbol.is_name(size):
+                        name = size.node.id
+                        if not naming.is_meta(name):
+                            size = naming.make_next_power_of_2(name)
+
+                    indices.append(call("arange", 0, size))
 
         return tuple(indices)
 
@@ -237,7 +263,7 @@ class Tensor:
     def _calculate_default_strides(shape):
         strides = [1]
 
-        for size in shape[1:]:
+        for size in reversed(shape[1:]):
            strides.append(size * strides[-1])
 
         return reversed(strides)
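The reworked `tile` above generalizes the per-dimension block count to strided and dilated tiling: `new_size = cdiv(size - spacing * (tile_size - 1) - 1, stride) + 1`, with the stride defaulting to the tile size and the dilation (spacing) defaulting to 1. A small plain-Python sketch of that arithmetic, assuming `cdiv` is ceiling division:

```python
import math


def num_tiles(size, tile_size, stride=None, spacing=1):
    # Mirrors the formula used by Tensor.tile for each dimension.
    if stride is None:
        stride = tile_size  # the -1 default: non-overlapping tiles
    if stride == 0:
        return -1
    return math.ceil((size - spacing * (tile_size - 1) - 1) / stride) + 1


# Non-overlapping tiles: an 8192-element vector in blocks of 1024 gives 8 tiles.
assert num_tiles(8192, 1024) == 8

# Overlapping tiles: size 10, tile 4, stride 2 -> start positions 0, 2, 4, 6 -> 4 tiles.
assert num_tiles(10, 4, stride=2) == 4

# Dilation 2 spreads a tile of 3 elements over a span of 5: size 10, stride 1 -> 6 tiles.
assert num_tiles(10, 3, stride=1, spacing=2) == 6
```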
ninetoothed-0.7.0/tests/test_addmm.py

@@ -0,0 +1,104 @@
+import random
+
+import torch
+
+import ninetoothed
+import ninetoothed.language as ntl
+from ninetoothed import Symbol, Tensor
+from tests.skippers import skip_if_cuda_not_available, skip_if_float8_e5m2_not_supported
+
+
+def addmm(input, mat1, mat2, beta=1, alpha=1):
+    BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True)
+    BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True)
+    BLOCK_SIZE_K = Symbol("BLOCK_SIZE_K", meta=True)
+
+    input_tiled = Tensor(2).tile((BLOCK_SIZE_M, BLOCK_SIZE_N))
+
+    output_tiled = Tensor(2).tile((BLOCK_SIZE_M, BLOCK_SIZE_N))
+
+    mat1_tiled = (
+        Tensor(2)
+        .tile((BLOCK_SIZE_M, BLOCK_SIZE_K))
+        .tile((1, -1))
+        .expand((-1, output_tiled.shape[1]))
+    )
+    mat1_tiled.dtype = mat1_tiled.dtype.squeeze(0)
+
+    mat2_tiled = (
+        Tensor(2)
+        .tile((BLOCK_SIZE_K, BLOCK_SIZE_N))
+        .tile((-1, 1))
+        .expand((output_tiled.shape[0], -1))
+    )
+    mat2_tiled.dtype = mat2_tiled.dtype.squeeze(1)
+
+    @ninetoothed.jit
+    def addmm_kernel(
+        input: input_tiled,
+        mat1: mat1_tiled,
+        mat2: mat2_tiled,
+        beta: Tensor(0),
+        alpha: Tensor(0),
+        output: output_tiled,
+    ):
+        accumulator = ntl.zeros(output.shape, dtype=ntl.float32)
+        for k in range(mat1.shape[0]):
+            accumulator += ntl.dot(mat1[k], mat2[k])
+        output = beta * input + alpha * accumulator.to(ntl.float16)
+
+    output = torch.empty(
+        (mat1.shape[0], mat2.shape[1]), device=mat1.device, dtype=torch.float16
+    )
+
+    addmm_kernel(input, mat1, mat2, beta, alpha, output)
+
+    return output
+
+
+@skip_if_cuda_not_available
+class TestCUDA:
+    @classmethod
+    def setup_class(cls):
+        torch.manual_seed(0)
+
+        shape = (512, 512)
+
+        cls.input = torch.randn(shape, device="cuda")
+        cls.mat1 = torch.randn(shape, device="cuda")
+        cls.mat2 = torch.randn(shape, device="cuda")
+        cls.beta = random.uniform(0, 1)
+        cls.alpha = random.uniform(0, 1)
+
+    def test_fp16(self):
+        input = type(self).input.to(torch.float16)
+        mat1 = type(self).mat1.to(torch.float16)
+        mat2 = type(self).mat2.to(torch.float16)
+        beta = type(self).beta
+        alpha = type(self).alpha
+
+        assert torch.allclose(
+            addmm(input, mat1, mat2, beta=beta, alpha=alpha),
+            torch.addmm(input, mat1, mat2, beta=beta, alpha=alpha),
+            atol=0.075,
+        )
+
+    @skip_if_float8_e5m2_not_supported
+    def test_fp8(self):
+        input = type(self).input.to(torch.float8_e5m2)
+        mat1 = type(self).mat1.to(torch.float8_e5m2)
+        mat2 = type(self).mat2.T.to(torch.float8_e5m2)
+        beta = type(self).beta
+        alpha = type(self).alpha
+
+        assert torch.allclose(
+            addmm(input, mat1, mat2, beta=beta, alpha=alpha),
+            torch.addmm(
+                input.to(torch.float16),
+                mat1.to(torch.float16),
+                mat2.to(torch.float16),
+                beta=beta,
+                alpha=alpha,
+            ),
+            atol=0.125,
+        )
ninetoothed-0.7.0/tests/test_naming.py

@@ -0,0 +1,51 @@
+import ninetoothed.naming as naming
+
+
+def test_make_constexpr():
+    assert naming.make_constexpr(_NAME) == f"ninetoothed_constexpr_prefix_{_NAME}"
+
+
+def test_make_meta():
+    assert naming.make_meta(_NAME) == f"ninetoothed_meta_prefix_{_NAME}"
+
+
+def test_make_next_power_of_2():
+    assert (
+        naming.make_next_power_of_2(_NAME)
+        == f"ninetoothed_next_power_of_2_prefix_{_NAME}"
+    )
+    assert (
+        naming.make_next_power_of_2(naming.make_constexpr(_NAME))
+        == f"ninetoothed_next_power_of_2_prefix_ninetoothed_constexpr_prefix_{_NAME}"
+    )
+
+
+def test_is_constexpr():
+    assert naming.is_constexpr(naming.make_constexpr(_NAME))
+    assert naming.is_constexpr(naming.make_meta(_NAME))
+
+
+def test_is_meta():
+    assert naming.is_meta(naming.make_meta(_NAME))
+
+
+def test_is_next_power_of_2():
+    assert naming.is_next_power_of_2(naming.make_next_power_of_2(_NAME))
+    assert naming.is_next_power_of_2(
+        naming.make_next_power_of_2(naming.make_constexpr(_NAME))
+    )
+
+
+def test_remove_prefixes():
+    assert naming.remove_prefixes(naming.make_constexpr(_NAME)) == _NAME
+    assert naming.remove_prefixes(naming.make_meta(_NAME)) == _NAME
+    assert naming.remove_prefixes(naming.make_next_power_of_2(_NAME)) == _NAME
+    assert (
+        naming.remove_prefixes(
+            naming.make_next_power_of_2(naming.make_constexpr(_NAME))
+        )
+        == _NAME
+    )
+
+
+_NAME = "ninetoothed_name"
{ninetoothed-0.5.0 → ninetoothed-0.7.0}/tests/test_softmax.py

@@ -1,5 +1,4 @@
 import torch
-import triton
 
 import ninetoothed
 import ninetoothed.language as ntl
@@ -22,7 +21,7 @@ def softmax(input):
 
     output = torch.empty_like(input)
 
-    softmax_kernel(input, output, BLOCK_SIZE=
+    softmax_kernel(input, output, BLOCK_SIZE=input.shape[-1])
 
     return output
 
|