ninetoothed 0.2.0__tar.gz → 0.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ninetoothed-0.2.0 → ninetoothed-0.4.0}/PKG-INFO +9 -5
- {ninetoothed-0.2.0 → ninetoothed-0.4.0}/README.md +7 -4
- {ninetoothed-0.2.0 → ninetoothed-0.4.0}/docs/README.zh.md +7 -4
- {ninetoothed-0.2.0 → ninetoothed-0.4.0}/pyproject.toml +2 -1
- {ninetoothed-0.2.0 → ninetoothed-0.4.0}/src/ninetoothed/jit.py +127 -39
- {ninetoothed-0.2.0 → ninetoothed-0.4.0}/src/ninetoothed/language.py +4 -1
- {ninetoothed-0.2.0 → ninetoothed-0.4.0}/src/ninetoothed/symbol.py +57 -5
- {ninetoothed-0.2.0 → ninetoothed-0.4.0}/src/ninetoothed/tensor.py +96 -44
- ninetoothed-0.4.0/src/ninetoothed/torchifier.py +38 -0
- {ninetoothed-0.2.0 → ninetoothed-0.4.0}/tests/test_matmul.py +5 -2
- ninetoothed-0.4.0/tests/test_softmax.py +41 -0
- ninetoothed-0.2.0/src/ninetoothed/torchifier.py +0 -34
- {ninetoothed-0.2.0 → ninetoothed-0.4.0}/.github/workflows/pytest.yml +0 -0
- {ninetoothed-0.2.0 → ninetoothed-0.4.0}/.github/workflows/ruff.yml +0 -0
- {ninetoothed-0.2.0 → ninetoothed-0.4.0}/.gitignore +0 -0
- {ninetoothed-0.2.0 → ninetoothed-0.4.0}/LICENSE +0 -0
- {ninetoothed-0.2.0 → ninetoothed-0.4.0}/requirements.txt +0 -0
- {ninetoothed-0.2.0 → ninetoothed-0.4.0}/src/ninetoothed/__init__.py +0 -0
- {ninetoothed-0.2.0 → ninetoothed-0.4.0}/tests/__init__.py +0 -0
- {ninetoothed-0.2.0 → ninetoothed-0.4.0}/tests/skippers.py +0 -0
- {ninetoothed-0.2.0 → ninetoothed-0.4.0}/tests/test_add.py +0 -0
--- ninetoothed-0.2.0/PKG-INFO
+++ ninetoothed-0.4.0/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: ninetoothed
-Version: 0.2.0
+Version: 0.4.0
 Summary: A domain-specific language based on Triton but providing higher-level abstraction.
 Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
 Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
@@ -10,6 +10,7 @@ Classifier: License :: OSI Approved :: Apache Software License
 Classifier: Operating System :: OS Independent
 Classifier: Programming Language :: Python :: 3
 Requires-Python: >=3.10
+Requires-Dist: triton>=3.0.0
 Description-Content-Type: text/markdown
 
 # NineToothed
@@ -26,7 +27,7 @@ We can use `pip` to install `ninetoothed`.
 pip install ninetoothed
 ```
 
-After successfully running the above command, `ninetoothed` will be installed. However, to fully utilize its capabilities, you also need to install
+After successfully running the above command, `ninetoothed` will be installed. However, to fully utilize its capabilities, you also need to install a deep learning framework supported by `ninetoothed`. For trial purposes, we recommend installing `torch`.
 
 ## Usage
 
@@ -64,16 +65,19 @@ c_tiled = Tensor(2).tile((BLOCK_SIZE_M, BLOCK_SIZE_N))
 a_tiled = a_tiled.expand((-1, c_tiled.shape[1]))
 b_tiled = b_tiled.expand((c_tiled.shape[0], -1))
 
+a_tiled.dtype = a_tiled.dtype.squeeze(0)
+b_tiled.dtype = b_tiled.dtype.squeeze(1)
+
 @ninetoothed.jit
 def matmul_kernel(a: a_tiled, b: b_tiled, c: c_tiled):
     accumulator = ninetoothed.language.zeros(
         c.shape, dtype=ninetoothed.language.float32
     )
-    for k in range(a.shape[
-        accumulator
+    for k in range(a.shape[0]):
+        accumulator += ninetoothed.language.dot(a[k], b[k])
     c = accumulator.to(ninetoothed.language.float16)
 ```
 
-For matrix multiplication, we also have three tensor parameters, but the tiling method is more complex than vector addition. We denote the three matrices as $A$, $B$, and $C$, where $A$ and $B$ are inputs, and $C$ is the output. Tiling $C$ is simple; we just need to divide it into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_N)` by rows and columns. Once each block computes its result, the entire $C$ is computed. However, how should we tile $A$ and $B$? The answer is to introduce another meta-parameter `BLOCK_SIZE_K`. This way, we can divide $A$ into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_K)` and $B$ into blocks of size `(BLOCK_SIZE_K, BLOCK_SIZE_N)`. However, for matrix multiplication, $A$ and $B$ do not correspond block by block; each row of $A$ needs to correspond to each column of $B$. Therefore, we need to further `tile` $A$ and $B$ by rows and columns, respectively. Up to this point, we have a set of row blocks of $A$ and column blocks of $B$. However, each row block of $A$ must correspond to every column block of $B$. This is where `expand` comes in. We `expand` the row blocks of $A$ along the columns to the number of columns of $C$ and the column blocks of $B$ along the rows to the number of rows of $C$. This way, we successfully tile $A$, $B$, and $C$.
+For matrix multiplication, we also have three tensor parameters, but the tiling method is more complex than vector addition. We denote the three matrices as $A$, $B$, and $C$, where $A$ and $B$ are inputs, and $C$ is the output. Tiling $C$ is simple; we just need to divide it into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_N)` by rows and columns. Once each block computes its result, the entire $C$ is computed. However, how should we tile $A$ and $B$? The answer is to introduce another meta-parameter `BLOCK_SIZE_K`. This way, we can divide $A$ into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_K)` and $B$ into blocks of size `(BLOCK_SIZE_K, BLOCK_SIZE_N)`. However, for matrix multiplication, $A$ and $B$ do not correspond block by block; each row of $A$ needs to correspond to each column of $B$. Therefore, we need to further `tile` $A$ and $B$ by rows and columns, respectively. Up to this point, we have a set of row blocks of $A$ and column blocks of $B$. However, each row block of $A$ must correspond to every column block of $B$. This is where `expand` comes in. We `expand` the row blocks of $A$ along the columns to the number of columns of $C$ and the column blocks of $B$ along the rows to the number of rows of $C$. This way, we successfully tile $A$, $B$, and $C$. In fact, our meta-operations up to this point have already enabled us to write kernel functions. However, we notice that the levels where the row blocks and column blocks reside, which we mentioned earlier, are two-dimensional, and their sizes are of the forms `(1, ...)` and `(..., 1)`. This means that if no other operations are performed, the way we access row blocks and column blocks would have to be `a[0, k]` and `b[k, 0]`. If we want to use `a` to find the range of `k`, we would need to use `a.shape[1]`, but we know that dimensions of size `1` can actually be removed completely. This is why we added two lines of `squeeze`. The `dtype` refers to the data type, which in PyTorch can generally be some integer or floating-point type, such as `torch.float32`. However, since meta-operations like `tile` can be performed in NineToothed, `dtype` can also be a `Tensor`. In other words, there is a concept of "tensors that store tensors" in NineToothed. In summary, these two lines perform operations on the tensors stored in the outmost tensor, removing the dimensions of size `1`. This way, when we access the row and column blocks, we can use `a[k]` and `b[k]`, and when finding the range of `k`, we can use `a.shape[0]`.
 
 With tiling done, the rest is simple. In the function body, we define an `accumulator` to accumulate intermediate results. We then iterate through the corresponding row blocks of $A$ and column blocks of B, multiplying them and accumulating the results in `accumulator`. Finally, we place the `accumulator` in the corresponding block of $C$. Since each block of the parameter tensors undergoes this operation, the multiplication is completed for the whole tensors as well.
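The `squeeze` step introduced in the paragraph above is easiest to see by printing how many dimensions each nesting level of `a_tiled` has. The following is a standalone sketch, not part of the package: it assumes ninetoothed 0.4.0 is installed, creates the block sizes with `constexpr=True` (as the new softmax test does), and skips the `expand` step for brevity.

```python
from ninetoothed import Symbol, Tensor

BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", constexpr=True)
BLOCK_SIZE_K = Symbol("BLOCK_SIZE_K", constexpr=True)

# Outer level: a grid of row blocks; middle level: a (1, ...) row of K-blocks;
# inner level: one (BLOCK_SIZE_M, BLOCK_SIZE_K) block.
a_tiled = Tensor(2).tile((BLOCK_SIZE_M, BLOCK_SIZE_K)).tile((1, -1))

print(a_tiled.ndim)              # 2
print(a_tiled.dtype.ndim)        # 2, with a leading dimension of size 1
print(a_tiled.dtype.dtype.ndim)  # 2

# Dropping the size-1 dimension lets the kernel write a[k] and a.shape[0]
# instead of a[0, k] and a.shape[1].
a_tiled.dtype = a_tiled.dtype.squeeze(0)
print(a_tiled.dtype.ndim)        # 1
```

The same reasoning applies to `b_tiled`, whose middle level has shape `(..., 1)` and is therefore squeezed along dimension 1.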
--- ninetoothed-0.2.0/README.md
+++ ninetoothed-0.4.0/README.md
@@ -12,7 +12,7 @@ We can use `pip` to install `ninetoothed`.
 pip install ninetoothed
 ```
 
-After successfully running the above command, `ninetoothed` will be installed. However, to fully utilize its capabilities, you also need to install
+After successfully running the above command, `ninetoothed` will be installed. However, to fully utilize its capabilities, you also need to install a deep learning framework supported by `ninetoothed`. For trial purposes, we recommend installing `torch`.
 
 ## Usage
 
@@ -50,16 +50,19 @@ c_tiled = Tensor(2).tile((BLOCK_SIZE_M, BLOCK_SIZE_N))
 a_tiled = a_tiled.expand((-1, c_tiled.shape[1]))
 b_tiled = b_tiled.expand((c_tiled.shape[0], -1))
 
+a_tiled.dtype = a_tiled.dtype.squeeze(0)
+b_tiled.dtype = b_tiled.dtype.squeeze(1)
+
 @ninetoothed.jit
 def matmul_kernel(a: a_tiled, b: b_tiled, c: c_tiled):
     accumulator = ninetoothed.language.zeros(
         c.shape, dtype=ninetoothed.language.float32
     )
-    for k in range(a.shape[
-        accumulator
+    for k in range(a.shape[0]):
+        accumulator += ninetoothed.language.dot(a[k], b[k])
     c = accumulator.to(ninetoothed.language.float16)
 ```
 
-For matrix multiplication, we also have three tensor parameters, but the tiling method is more complex than vector addition. We denote the three matrices as $A$, $B$, and $C$, where $A$ and $B$ are inputs, and $C$ is the output. Tiling $C$ is simple; we just need to divide it into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_N)` by rows and columns. Once each block computes its result, the entire $C$ is computed. However, how should we tile $A$ and $B$? The answer is to introduce another meta-parameter `BLOCK_SIZE_K`. This way, we can divide $A$ into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_K)` and $B$ into blocks of size `(BLOCK_SIZE_K, BLOCK_SIZE_N)`. However, for matrix multiplication, $A$ and $B$ do not correspond block by block; each row of $A$ needs to correspond to each column of $B$. Therefore, we need to further `tile` $A$ and $B$ by rows and columns, respectively. Up to this point, we have a set of row blocks of $A$ and column blocks of $B$. However, each row block of $A$ must correspond to every column block of $B$. This is where `expand` comes in. We `expand` the row blocks of $A$ along the columns to the number of columns of $C$ and the column blocks of $B$ along the rows to the number of rows of $C$. This way, we successfully tile $A$, $B$, and $C$.
+For matrix multiplication, we also have three tensor parameters, but the tiling method is more complex than vector addition. We denote the three matrices as $A$, $B$, and $C$, where $A$ and $B$ are inputs, and $C$ is the output. Tiling $C$ is simple; we just need to divide it into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_N)` by rows and columns. Once each block computes its result, the entire $C$ is computed. However, how should we tile $A$ and $B$? The answer is to introduce another meta-parameter `BLOCK_SIZE_K`. This way, we can divide $A$ into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_K)` and $B$ into blocks of size `(BLOCK_SIZE_K, BLOCK_SIZE_N)`. However, for matrix multiplication, $A$ and $B$ do not correspond block by block; each row of $A$ needs to correspond to each column of $B$. Therefore, we need to further `tile` $A$ and $B$ by rows and columns, respectively. Up to this point, we have a set of row blocks of $A$ and column blocks of $B$. However, each row block of $A$ must correspond to every column block of $B$. This is where `expand` comes in. We `expand` the row blocks of $A$ along the columns to the number of columns of $C$ and the column blocks of $B$ along the rows to the number of rows of $C$. This way, we successfully tile $A$, $B$, and $C$. In fact, our meta-operations up to this point have already enabled us to write kernel functions. However, we notice that the levels where the row blocks and column blocks reside, which we mentioned earlier, are two-dimensional, and their sizes are of the forms `(1, ...)` and `(..., 1)`. This means that if no other operations are performed, the way we access row blocks and column blocks would have to be `a[0, k]` and `b[k, 0]`. If we want to use `a` to find the range of `k`, we would need to use `a.shape[1]`, but we know that dimensions of size `1` can actually be removed completely. This is why we added two lines of `squeeze`. The `dtype` refers to the data type, which in PyTorch can generally be some integer or floating-point type, such as `torch.float32`. However, since meta-operations like `tile` can be performed in NineToothed, `dtype` can also be a `Tensor`. In other words, there is a concept of "tensors that store tensors" in NineToothed. In summary, these two lines perform operations on the tensors stored in the outmost tensor, removing the dimensions of size `1`. This way, when we access the row and column blocks, we can use `a[k]` and `b[k]`, and when finding the range of `k`, we can use `a.shape[0]`.
 
 With tiling done, the rest is simple. In the function body, we define an `accumulator` to accumulate intermediate results. We then iterate through the corresponding row blocks of $A$ and column blocks of B, multiplying them and accumulating the results in `accumulator`. Finally, we place the `accumulator` in the corresponding block of $C$. Since each block of the parameter tensors undergoes this operation, the multiplication is completed for the whole tensors as well.
--- ninetoothed-0.2.0/docs/README.zh.md
+++ ninetoothed-0.4.0/docs/README.zh.md
@@ -12,7 +12,7 @@
 pip install ninetoothed
 ```
 
-After successfully running the above two commands, `ninetoothed` will be installed. However, besides `ninetoothed`
+After successfully running the above two commands, `ninetoothed` will be installed. However, besides `ninetoothed` itself, to truly put it to use, we also need to install at least one deep learning framework supported by `ninetoothed`. For trial purposes, we recommend installing `torch`.
 
 ## Usage
 
@@ -50,16 +50,19 @@ c_tiled = Tensor(2).tile((BLOCK_SIZE_M, BLOCK_SIZE_N))
 a_tiled = a_tiled.expand((-1, c_tiled.shape[1]))
 b_tiled = b_tiled.expand((c_tiled.shape[0], -1))
 
+a_tiled.dtype = a_tiled.dtype.squeeze(0)
+b_tiled.dtype = b_tiled.dtype.squeeze(1)
+
 @ninetoothed.jit
 def matmul_kernel(a: a_tiled, b: b_tiled, c: c_tiled):
     accumulator = ninetoothed.language.zeros(
         c.shape, dtype=ninetoothed.language.float32
     )
-    for k in range(a.shape[
-        accumulator
+    for k in range(a.shape[0]):
+        accumulator += ninetoothed.language.dot(a[k], b[k])
     c = accumulator.to(ninetoothed.language.float16)
 ```
 
-For matrix multiplication, we also have three tensor parameters, but the tiling method is certainly more complex than vector addition. We denote the three matrices as $A$, $B$, and $C$, where $A$ and $B$ are inputs and $C$ is the output. Tiling $C$ is simple: we just divide it by rows and columns into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_N)`; once every such block has computed its result, all of $C$ has been computed. So how should we tile $A$ and $B$? The answer is to introduce another meta-parameter `BLOCK_SIZE_K`, so that we can divide $A$ into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_K)` and $B$ into blocks of size `(BLOCK_SIZE_K, BLOCK_SIZE_N)`. However, for matrix multiplication, $A$ and $B$ do not correspond block by block; instead, each row of $A$ needs to correspond to each column of $B$, so we need to keep applying `tile`, further dividing $A$ and $B$ into row-wise and column-wise blocks, respectively. At this point we have a set of row blocks of $A$ and column blocks of $B$, but each row block of $A$ must correspond to every column block of $B$. This is where `expand` comes in: we `expand` the row blocks of $A$ along the columns to as many columns as $C$ has, and the column blocks of $B$ along the rows to as many rows as $C$ has. In this way, we have successfully tiled $A$, $B$, and $C$, and for each block of $C$ we have the matching row blocks of $A$ and
+For matrix multiplication, we also have three tensor parameters, but the tiling method is certainly more complex than vector addition. We denote the three matrices as $A$, $B$, and $C$, where $A$ and $B$ are inputs and $C$ is the output. Tiling $C$ is simple: we just divide it by rows and columns into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_N)`; once every such block has computed its result, all of $C$ has been computed. So how should we tile $A$ and $B$? The answer is to introduce another meta-parameter `BLOCK_SIZE_K`, so that we can divide $A$ into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_K)` and $B$ into blocks of size `(BLOCK_SIZE_K, BLOCK_SIZE_N)`. However, for matrix multiplication, $A$ and $B$ do not correspond block by block; instead, each row of $A$ needs to correspond to each column of $B$, so we need to keep applying `tile`, further dividing $A$ and $B$ into row-wise and column-wise blocks, respectively. At this point we have a set of row blocks of $A$ and column blocks of $B$, but each row block of $A$ must correspond to every column block of $B$. This is where `expand` comes in: we `expand` the row blocks of $A$ along the columns to as many columns as $C$ has, and the column blocks of $B$ along the rows to as many rows as $C$ has. In this way, we have successfully tiled $A$, $B$, and $C$, and for each block of $C$ we have the matching row blocks of $A$ and column blocks of $B$. In fact, the meta-operations up to this point are already enough to write the kernel function. However, we notice that the levels where the row blocks and column blocks reside are two-dimensional, with sizes of the forms `(1, ...)` and `(..., 1)`. That is, if we do nothing else, we would have to access the row blocks and column blocks as `a[0, k]` and `b[k, 0]`, and if we want to use `a` to find the range of `k`, we would have to use `a.shape[1]`. But we know that dimensions of size `1` can simply be removed, which is why we added the two `squeeze` lines. Here `dtype` means the data type, which in PyTorch is usually some integer or floating-point type such as `torch.float32`; but since meta-operations like `tile` can be performed in NineToothed, `dtype` can also be a `Tensor`. In other words, NineToothed has the concept of "tensors that store tensors". In summary, these two lines operate on the next-level tensors stored in the outermost tensor, removing the dimensions of size `1`, so that when accessing the row and column blocks we can use `a[k]` and `b[k]`, and when finding the range of `k` we can use `a.shape[0]`.
 
 With the tiling matched up, the rest is much simpler. In the function body, we define an `accumulator` to accumulate intermediate results, then iterate over the matched row blocks of $A$ and column blocks of $B$, multiplying them and accumulating the results into `accumulator`, and finally place `accumulator` into the corresponding block of $C$. Since every block of the parameter tensors undergoes this operation, the multiplication is completed for the tensors as a whole as well.
--- ninetoothed-0.2.0/pyproject.toml
+++ ninetoothed-0.4.0/pyproject.toml
@@ -4,10 +4,11 @@ build-backend = "hatchling.build"
 
 [project]
 name = "ninetoothed"
-version = "0.2.0"
+version = "0.4.0"
 authors = [{ name = "Jiacheng Huang", email = "huangjiacheng0709@outlook.com" }]
 description = "A domain-specific language based on Triton but providing higher-level abstraction."
 readme = "README.md"
+dependencies = ["triton>=3.0.0"]
 requires-python = ">=3.10"
 classifiers = [
     "Programming Language :: Python :: 3",
--- ninetoothed-0.2.0/src/ninetoothed/jit.py
+++ ninetoothed-0.4.0/src/ninetoothed/jit.py
@@ -6,6 +6,8 @@ import itertools
 import math
 import tempfile
 
+import triton
+
 from ninetoothed.language import attribute, call
 from ninetoothed.symbol import Symbol
 from ninetoothed.tensor import Tensor
@@ -103,13 +105,24 @@ class CodeGenerator(ast.NodeTransformer):
             if not isinstance(arg, Tensor):
                 continue
 
- [7 removed lines not shown in this diff view]
+            offsets = arg.offsets()
+
+            initializations = {
+                type(self)._name_for_offsets(arg, dim): offs
+                for dim, offs in enumerate(offsets)
+            } | {
+                type(self)._name_for_pointers(arg): arg.original.pointer_string()
+                + sum(
+                    type(self)._name_for_offsets(arg, dim)[
+                        type(self)._generate_slices(arg, dim)
+                    ]
+                    * stride
+                    for dim, stride in enumerate(arg.original.strides)
+                )
+            }
+
+            for target, value in reversed(initializations.items()):
+                node.body.insert(0, ast.Assign(targets=[target.node], value=value.node))
 
         return node
 
@@ -124,7 +137,7 @@ class CodeGenerator(ast.NodeTransformer):
         node.args = [
             ast.arg(arg=name)
             if not Symbol.is_constexpr(name)
-            else ast.arg(arg=name, annotation=attribute("constexpr"))
+            else ast.arg(arg=name, annotation=attribute("constexpr").node)
             for name in non_meta_names
         ] + [
             ast.arg(arg=name, annotation=attribute("constexpr").node)
@@ -147,15 +160,13 @@ class CodeGenerator(ast.NodeTransformer):
         value = self._context[node.value.id]
 
         if isinstance(value, Tensor):
- [removed line not shown in this diff view]
+            return type(self)._generate_load(
                 value,
-                node.slice.elts
+                intermediate_indices=node.slice.elts
                 if isinstance(node.slice, ast.Tuple)
                 else (node.slice,),
             )
 
-        return call("load", pointers).node
-
         self.generic_visit(node)
 
         return node
@@ -177,9 +188,7 @@ class CodeGenerator(ast.NodeTransformer):
         self.generic_visit(node)
 
         if node.id in self._context and isinstance(node.ctx, ast.Load):
-            return
-                "load", type(self)._create_pointers(self._context[node.id], ()).node
-            ).node
+            return type(self)._generate_load(self._context[node.id])
 
         return node
 
@@ -191,11 +200,7 @@ class CodeGenerator(ast.NodeTransformer):
         self.generic_visit(node)
 
             return ast.Expr(
- [removed line not shown in this diff view]
-                "store",
-                type(self)._create_pointers(self._context[target.id], ()).node,
-                node.value,
-                ).node
+                type(self)._generate_store(self._context[target.id], node.value)
             )
         elif (
             isinstance(target, ast.Subscript)
@@ -208,19 +213,14 @@ class CodeGenerator(ast.NodeTransformer):
             if isinstance(value, Tensor):
                 self.generic_visit(node)
 
-                pointers = type(self)._create_pointers(
-                    value,
-                    target.slice.elts
-                    if isinstance(target.slice, ast.Tuple)
-                    else (target.slice,),
-                )
-
                 return ast.Expr(
- [removed line not shown in this diff view]
- [removed line not shown in this diff view]
-                    pointers.node,
+                    type(self)._generate_store(
+                        value,
                         node.value,
- [removed line not shown in this diff view]
+                        intermediate_indices=target.slice.elts
+                        if isinstance(target.slice, ast.Tuple)
+                        else (target.slice,),
+                    )
                 )
 
         self.generic_visit(node)
@@ -228,6 +228,13 @@ class CodeGenerator(ast.NodeTransformer):
         return node
 
     def _generate_autotune(self, params, meta):
+        device = triton.runtime.driver.active.get_current_device()
+        properties = triton.runtime.driver.active.utils.get_device_properties(device)
+        max_shared_mem = properties["max_shared_mem"]
+
+        num_warps = 8
+        num_stages = max_shared_mem // 2**15
+
         configs = [
             ast.Call(
                 func=ast.Attribute(
@@ -241,7 +248,10 @@ class CodeGenerator(ast.NodeTransformer):
                         values=[ast.Constant(value=value) for value in values],
                     )
                 ],
-                keywords=[
+                keywords=[
+                    ast.keyword(arg="num_warps", value=ast.Constant(value=num_warps)),
+                    ast.keyword(arg="num_stages", value=ast.Constant(value=num_stages)),
+                ],
             )
             for values in itertools.product(self._POWER_OF_TWOS, repeat=len(meta))
             if self._MIN_PRODUCT <= math.prod(values) <= self._MAX_PRODUCT
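The `_generate_autotune` change above derives `num_warps` and `num_stages` from the active device instead of leaving the Triton defaults. The device query it relies on can be reproduced on its own with the same Triton runtime calls; this is only an illustrative sketch (it needs `triton>=3.0.0` and a CUDA-capable GPU to run, and the printed values are hardware-dependent):

```python
import triton

# Same device query the generated autotune configs use.
device = triton.runtime.driver.active.get_current_device()
properties = triton.runtime.driver.active.utils.get_device_properties(device)
max_shared_mem = properties["max_shared_mem"]

# Heuristics mirrored from the diff: fixed warp count, pipeline depth
# scaled by the amount of shared memory per SM (32 KiB per stage).
num_warps = 8
num_stages = max_shared_mem // 2**15

print(f"max_shared_mem={max_shared_mem} -> num_warps={num_warps}, num_stages={num_stages}")
```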
@@ -268,7 +278,7 @@ class CodeGenerator(ast.NodeTransformer):
             elts=[
                 ast.Constant(value=param)
                 for param in params
-                if not Tensor.
+                if not Tensor.pointer_pattern().fullmatch(param)
             ],
             ctx=ast.Load(),
         ),
@@ -277,15 +287,30 @@ class CodeGenerator(ast.NodeTransformer):
         )
 
     def _generate_launch(self, params, meta):
+        constexpr_params = [param for param in params if Symbol.is_constexpr(param)]
+        constexpr_params_without_prefixes = [
+            Symbol.remove_prefix(param) for param in constexpr_params
+        ]
+
         launch = ast.FunctionDef(
             name=f"launch_{self._func_def.name}",
             args=ast.arguments(
                 posonlyargs=[],
-                args=[ast.arg(arg.name) for arg in self._args]
+                args=[ast.arg(arg=arg.original.name) for arg in self._args]
+                + [ast.arg(arg=param) for param in constexpr_params_without_prefixes],
                 kwonlyargs=[],
                 defaults=[],
             ),
             body=[
+                ast.Assign(
+                    targets=[ast.Name(id=param, ctx=ast.Store())],
+                    value=ast.Name(id=param_without_prefix, ctx=ast.Load()),
+                )
+                for param, param_without_prefix in zip(
+                    constexpr_params, constexpr_params_without_prefixes
+                )
+            ]
+            + [
                 ast.Expr(
                     ast.Call(
                         func=ast.Subscript(
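The generated launch function now takes constexpr parameters under their user-facing names and assigns them to the internally prefixed names before invoking the kernel. A small sketch of that prefix round-trip, assuming the `Symbol` API from this release (the exact prefixed spelling is an internal detail and may differ):

```python
from ninetoothed.symbol import Symbol

BLOCK_SIZE = Symbol("BLOCK_SIZE", constexpr=True)
internal_name = str(BLOCK_SIZE)  # carries the internal _ninetoothed_constexpr_ prefix

print(Symbol.is_constexpr(internal_name))   # True
print(Symbol.remove_prefix(internal_name))  # BLOCK_SIZE
```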
@@ -329,13 +354,76 @@ class CodeGenerator(ast.NodeTransformer):
         return ast.parse(f"lambda meta: ({num_elements},)", mode="eval").body
 
     @staticmethod
-    def
- [4 removed lines not shown in this diff view]
+    def _generate_load(tensor, intermediate_indices=()):
+        pointers, mask = CodeGenerator._generate_pointers_and_mask(
+            tensor, intermediate_indices
+        )
+        other = CodeGenerator._generate_other(tensor)
+
+        return call("load", pointers, mask=mask, other=other).node
+
+    @staticmethod
+    def _generate_store(tensor, value, intermediate_indices=()):
+        pointers, mask = CodeGenerator._generate_pointers_and_mask(
+            tensor, intermediate_indices
+        )
+
+        return call("store", pointers, value, mask=mask).node
+
+    @staticmethod
+    def _generate_pointers_and_mask(tensor, intermediate_indices):
+        intermediate_offsets = CodeGenerator._generate_intermediate_offsets(
+            tensor, intermediate_indices
+        )
+        offsets = [
+            CodeGenerator._name_for_offsets(tensor, dim) + intermediate_offsets[dim]
+            for dim in range(tensor.original.ndim)
+        ]
+        pointers = CodeGenerator._name_for_pointers(tensor) + sum(
+            map(lambda x, y: x * y, intermediate_offsets, tensor.original.strides)
+        )
+        mask = functools.reduce(
+            lambda x, y: x & y,
+            (
+                offs[CodeGenerator._generate_slices(tensor, dim)] < size
+                for dim, (offs, size) in enumerate(zip(offsets, tensor.original.shape))
+            ),
+        )
+
+        return pointers, mask
+
+    @staticmethod
+    def _generate_other(tensor):
+        other = tensor.original.other
+
+        if isinstance(other, float) and not math.isfinite(other):
+            return f"float('{other}')"
+
+        return other
+
+    @staticmethod
+    def _generate_slices(tensor, dim):
+        return tuple(slice(None) if i == dim else None for i in range(tensor.ndim))
+
+    @staticmethod
+    def _generate_intermediate_offsets(tensor, intermediate_indices):
+        return tuple(
+            offs
+            for offs in tensor.offsets(
+                [0 for _ in range(tensor.ndim)]
+                + list(intermediate_indices)
+                + [0 for _ in range(tensor.inmost().ndim)]
+            )
         )
 
+    @staticmethod
+    def _name_for_pointers(tensor):
+        return Symbol(f"{tensor.original.name}_pointers")
+
+    @staticmethod
+    def _name_for_offsets(tensor, dim):
+        return Symbol(f"{tensor.original.name}_offsets_{dim}")
+
 
 class Tritonizer(ast.NodeTransformer):
     def visit_Module(self, node):
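The new `_generate_load`/`_generate_store` helpers emit masked, strided accesses instead of bare `load`/`store` calls. The generated AST is awkward to show in isolation, so here is a hand-written Triton kernel illustrating the same pattern (this is not ninetoothed output, just a sketch): per-dimension offsets, pointer arithmetic over the original strides, a mask built by AND-ing per-dimension bounds checks, and an `other` fill value for out-of-bounds loads.

```python
import triton
import triton.language as tl


@triton.jit
def copy_2d_kernel(
    input_ptr, output_ptr,
    size_0, size_1, stride_0, stride_1,
    BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr,
):
    # Per-dimension offsets of this program's block.
    offs_0 = tl.program_id(0) * BLOCK_0 + tl.arange(0, BLOCK_0)
    offs_1 = tl.program_id(1) * BLOCK_1 + tl.arange(0, BLOCK_1)

    # Flattened pointers over the original strides.
    pointers = input_ptr + offs_0[:, None] * stride_0 + offs_1[None, :] * stride_1

    # Bounds mask: AND of per-dimension comparisons against the original sizes.
    mask = (offs_0[:, None] < size_0) & (offs_1[None, :] < size_1)

    block = tl.load(pointers, mask=mask, other=0)

    out_pointers = output_ptr + offs_0[:, None] * stride_0 + offs_1[None, :] * stride_1
    tl.store(out_pointers, block, mask=mask)
```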
--- ninetoothed-0.2.0/src/ninetoothed/language.py
+++ ninetoothed-0.4.0/src/ninetoothed/language.py
@@ -10,7 +10,10 @@ def call(func, *args, **kwargs):
         ast.Call(
             func=attribute(func).node,
             args=[Symbol(arg).node for arg in args],
-            keywords=[
+            keywords=[
+                ast.keyword(arg=kwarg, value=Symbol(kwargs[kwarg]).node)
+                for kwarg in kwargs
+            ],
         )
     )
 
--- ninetoothed-0.2.0/src/ninetoothed/symbol.py
+++ ninetoothed-0.4.0/src/ninetoothed/symbol.py
@@ -78,16 +78,36 @@ class Symbol:
         )
 
     def __mod__(self, other):
+        other = type(self)(other)
+
+        return type(self)(ast.BinOp(left=self._node, op=ast.Mod(), right=other._node))
+
+    def __lt__(self, other):
+        other = type(self)(other)
+
         return type(self)(
-            ast.
+            ast.Compare(left=self._node, ops=[ast.Lt()], comparators=[other._node])
         )
 
+    def __and__(self, other):
+        other = type(self)(other)
+
+        return type(self)(
+            ast.BinOp(left=self._node, op=ast.BitAnd(), right=other._node)
+        )
+
+    def __rand__(self, other):
+        return self.__and__(other)
+
     def __getitem__(self, key):
         return type(self)(ast.Subscript(value=self._node, slice=type(self)(key)._node))
 
     def __repr__(self):
         return ast.unparse(self._node)
 
+    def find_and_replace(self, target, replacement):
+        _FindAndReplacer(target.node, replacement.node).visit(self._node)
+
     def names(self):
         class NameCollector(ast.NodeVisitor):
             def __init__(self):
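The new `__mod__`, `__lt__`, `__and__`, and `__rand__` overloads let offset and mask expressions be composed symbolically: each operation wraps an `ast` node, and `__repr__` unparses it back to source. A minimal sketch of what that composition produces (assumes ninetoothed 0.4.0 is installed; the names are arbitrary):

```python
from ninetoothed.symbol import Symbol

offs = Symbol("offs")
size = Symbol("size")

# Build the kind of bounds-check expression the code generator assembles.
mask = (offs % 4 < size) & Symbol("other_mask")
print(mask)  # roughly: (offs % 4 < size) & other_mask
```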
@@ -117,16 +137,48 @@ class Symbol:
 
     @staticmethod
     def is_constexpr(name):
-        return name.startswith(
+        return name.startswith(Symbol._constexpr_prefix()) or Symbol.is_meta(name)
 
     @staticmethod
     def is_meta(name):
-        return name.startswith(
+        return name.startswith(Symbol._meta_prefix())
+
+    @staticmethod
+    def remove_prefix(name):
+        if name.startswith(Symbol._constexpr_prefix()):
+            return name.removeprefix(Symbol._constexpr_prefix())
+
+        if name.startswith(Symbol._meta_prefix()):
+            return name.removeprefix(Symbol._meta_prefix())
 
     @staticmethod
     def _create_constexpr(name):
-        return f"
+        return f"{Symbol._constexpr_prefix()}{name}"
 
     @staticmethod
     def _create_meta(name):
-        return f"
+        return f"{Symbol._meta_prefix()}{name}"
+
+    @staticmethod
+    def _constexpr_prefix():
+        return f"{Symbol._ninetoothed_prefix()}constexpr_"
+
+    @staticmethod
+    def _meta_prefix():
+        return f"{Symbol._ninetoothed_prefix()}meta_"
+
+    @staticmethod
+    def _ninetoothed_prefix():
+        return "_ninetoothed_"
+
+
+class _FindAndReplacer(ast.NodeTransformer):
+    def __init__(self, target, replacement):
+        self._target_id = target.id
+        self._replacement = replacement
+
+    def visit_Name(self, node):
+        if node.id == self._target_id:
+            return self._replacement
+
+        return self.generic_visit(node)
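`find_and_replace` and the `_FindAndReplacer` transformer are what `Tensor.offsets` uses to substitute a stride symbol with `1` inside an already-built expression. A small sketch of that in-place substitution (assumes ninetoothed 0.4.0; the names are arbitrary):

```python
from ninetoothed.symbol import Symbol

expr = Symbol("i") * Symbol("tensor_1_stride_0") + Symbol("j")
expr.find_and_replace(Symbol("tensor_1_stride_0"), Symbol(1))

print(expr)  # i * 1 + j
```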
--- ninetoothed-0.2.0/src/ninetoothed/tensor.py
+++ ninetoothed-0.4.0/src/ninetoothed/tensor.py
@@ -1,4 +1,5 @@
 import itertools
+import re
 
 from ninetoothed.language import call
 from ninetoothed.symbol import Symbol
@@ -7,19 +8,24 @@ from ninetoothed.symbol import Symbol
 class Tensor:
     num_instances = 0
 
-    def __init__(
+    def __init__(
+        self,
+        ndim=None,
+        shape=None,
+        dtype=None,
+        strides=None,
+        other=None,
+        original=None,
+    ):
         type(self).num_instances += 1
 
         self.dtype = dtype
 
- [removed line not shown in this diff view]
-            self.name = name
-        else:
-            self.name = f"tensor_{type(self).num_instances}"
+        self.name = f"tensor_{type(self).num_instances}"
 
         if ndim is not None:
-            self.shape = [Symbol(
-            self.strides = [Symbol(
+            self.shape = [Symbol(self.size_string(i)) for i in range(ndim)]
+            self.strides = [Symbol(self.stride_string(i)) for i in range(ndim)]
         else:
             self.shape = shape
 
@@ -28,6 +34,13 @@ class Tensor:
         else:
             self.strides = self._calculate_default_strides(shape)
 
+        self.other = other
+
+        if original is not None:
+            self.original = original
+        else:
+            self.original = self
+
     def tile(self, tile_shape, tile_strides=None):
         if tile_strides is None:
             tile_strides = [1 for _ in tile_shape]
@@ -59,10 +72,10 @@ class Tensor:
                 shape=inner_shape,
                 dtype=self.dtype,
                 strides=inner_strides,
- [removed line not shown in this diff view]
+                original=self.original,
             ),
             strides=outer_strides,
- [removed line not shown in this diff view]
+            original=self.original,
         )
 
     def expand(self, shape):
@@ -77,12 +90,21 @@ class Tensor:
                 stride if new_size == -1 else 0
                 for new_size, stride in zip(shape, self.strides)
             ],
- [removed line not shown in this diff view]
+            original=self.original,
+        )
+
+    def squeeze(self, dim):
+        # TODO: Add error handling.
+        return type(self)(
+            shape=[size for i, size in enumerate(self.shape) if dim != i],
+            dtype=self.dtype,
+            strides=[stride for i, stride in enumerate(self.strides) if dim != i],
+            original=self.original,
         )
 
     def names(self):
         return (
-            {self.
+            {self.original.pointer_string()}
             | {
                 name
                 for value in itertools.chain(self.shape, self.strides)
@@ -92,35 +114,31 @@ class Tensor:
             | (self.dtype.names() if isinstance(self.dtype, type(self)) else set())
         )
 
-    def pointers(self, offsets=None):
-        if offsets is None:
-            offsets = self.offsets()
-
-        return self._pointer() + offsets
-
     def offsets(self, indices=None):
         if indices is None:
             indices = self.indices()
 
- [removed line not shown in this diff view]
-        if len(indices) != self.ndim():
-            raise IndexError("Incorrect number of indices.")
+        offsets = [[] for _ in range(self.original.ndim)]
 
- [2 removed lines not shown in this diff view]
-            * self.stride(idx)
-            * call("arange", 0, self.size(idx))[
-                tuple(slice(None) if i == idx else None for i in range(self.ndim()))
-            ]
-            for idx in range(self.ndim())
-        )
+        curr = self
+        start = 0
 
- [2 removed lines not shown in this diff view]
+        while isinstance(curr, type(self)):
+            stop = start + curr.ndim
+            curr_indices = indices[start:stop]
 
- [3 removed lines not shown in this diff view]
+            for index, stride in zip(curr_indices, curr.strides):
+                for dim in self._dims_of(stride):
+                    offsets[dim].append(index * stride)
+
+            start = stop
+            curr = curr.dtype
+
+        for dim in range(self.original.ndim):
+            offsets[dim] = sum(offsets[dim])
+            offsets[dim].find_and_replace(Symbol(self.original.strides[dim]), Symbol(1))
+
+        return offsets
 
     def indices(self, index=None):
         if index is None:
@@ -128,17 +146,22 @@ class Tensor:
 
         indices = []
 
-        for stride in type(self)(shape=self.shape,
+        for stride in type(self)(shape=self.shape, original=self.original).strides:
            indices.append(index // stride)
            index %= stride
 
         curr = self.dtype
- [4 removed lines not shown in this diff view]
+
+        while isinstance(curr.dtype, type(self)):
+            for _ in range(curr.ndim):
+                indices.append(0)
 
             curr = curr.dtype
 
+        if isinstance(curr, type(self)):
+            for dim in range(curr.ndim):
+                indices.append(call("arange", 0, curr.shape[dim]))
+
         return tuple(indices)
 
     def inmost(self):
@@ -147,8 +170,14 @@ class Tensor:
 
         return self.dtype.inmost()
 
-    def
-        return
+    def pointer_string(self):
+        return f"{self.name}_pointer"
+
+    def size_string(self, dim):
+        return f"{self.name}_size_{dim}"
+
+    def stride_string(self, dim):
+        return f"{self.name}_stride_{dim}"
 
     def size(self, dim=None):
         if dim is None:
@@ -162,12 +191,31 @@ class Tensor:
 
         return self.strides[dim]
 
+    @property
+    def ndim(self):
+        return len(self.shape)
+
     @staticmethod
-    def
-        return
+    def pointer_pattern():
+        return re.compile(rf"({_identifier_pattern_raw_string()})_(pointer)")
 
- [2 removed lines not shown in this diff view]
+    @staticmethod
+    def size_pattern():
+        return re.compile(rf"({_identifier_pattern_raw_string()})_(size)_(.+)")
+
+    @staticmethod
+    def stride_pattern():
+        return re.compile(rf"({_identifier_pattern_raw_string()})_(stride)_(.+)")
+
+    def _dims_of(self, stride):
+        dims = set()
+        names = stride.names() if isinstance(stride, Symbol) else {stride}
+
+        for dim, original_stride in enumerate(self.original.strides):
+            if str(original_stride) in names:
+                dims.add(dim)
+
+        return dims
 
     @staticmethod
     def _calculate_default_strides(shape):
@@ -177,3 +225,7 @@ class Tensor:
             strides.append(size * strides[-1])
 
         return reversed(strides)
+
+
+def _identifier_pattern_raw_string():
+    return r"[a-zA-Z_][a-zA-Z0-9_]*"
--- /dev/null
+++ ninetoothed-0.4.0/src/ninetoothed/torchifier.py
@@ -0,0 +1,38 @@
+import ast
+
+from ninetoothed.tensor import Tensor
+
+
+class Torchifier(ast.NodeTransformer):
+    def visit_Name(self, node):
+        self.generic_visit(node)
+
+        source = node.id
+
+        def repl(match):
+            return f"{match.group(1)}"
+
+        source = Tensor.pointer_pattern().sub(repl, source)
+
+        def repl(match):
+            return f"{match.group(1)}.{match.group(2)}({match.group(3)})"
+
+        source = Tensor.size_pattern().sub(repl, source)
+        source = Tensor.stride_pattern().sub(repl, source)
+
+        if source != node.id:
+            return ast.parse(source, mode="eval").body
+
+        return node
+
+    def visit_Attribute(self, node):
+        self.generic_visit(node)
+
+        if (
+            isinstance(node.value, ast.Name)
+            and node.value.id == "ninetoothed"
+            and node.attr == "language"
+        ):
+            return node.value
+
+        return node
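The rewritten `Torchifier` no longer hard-codes its regular expressions; it reuses `Tensor.pointer_pattern()`, `Tensor.size_pattern()`, and `Tensor.stride_pattern()`. The effect on parameter names can be reproduced with plain `re`, using the same raw patterns defined in `tensor.py`; this standalone sketch is independent of the package:

```python
import re

IDENTIFIER = r"[a-zA-Z_][a-zA-Z0-9_]*"

pointer_pattern = re.compile(rf"({IDENTIFIER})_(pointer)")
size_pattern = re.compile(rf"({IDENTIFIER})_(size)_(.+)")
stride_pattern = re.compile(rf"({IDENTIFIER})_(stride)_(.+)")


def torchify_name(source):
    # "<name>_pointer" becomes "<name>"; "<name>_size_<d>" and "<name>_stride_<d>"
    # become the PyTorch calls "<name>.size(<d>)" and "<name>.stride(<d>)".
    source = pointer_pattern.sub(lambda m: m.group(1), source)
    source = size_pattern.sub(lambda m: f"{m.group(1)}.{m.group(2)}({m.group(3)})", source)
    source = stride_pattern.sub(lambda m: f"{m.group(1)}.{m.group(2)}({m.group(3)})", source)
    return source


assert torchify_name("tensor_1_pointer") == "tensor_1"
assert torchify_name("tensor_1_size_0") == "tensor_1.size(0)"
assert torchify_name("tensor_2_stride_1") == "tensor_2.stride(1)"
```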
--- ninetoothed-0.2.0/tests/test_matmul.py
+++ ninetoothed-0.4.0/tests/test_matmul.py
@@ -19,18 +19,21 @@ def matmul(lhs, rhs):
         .tile((1, -1))
         .expand((-1, output_tiled.shape[1]))
     )
+    lhs_tiled.dtype = lhs_tiled.dtype.squeeze(0)
+
     rhs_tiled = (
         Tensor(2)
         .tile((BLOCK_SIZE_K, BLOCK_SIZE_N))
         .tile((-1, 1))
         .expand((output_tiled.shape[0], -1))
     )
+    rhs_tiled.dtype = rhs_tiled.dtype.squeeze(1)
 
     @ninetoothed.jit
     def matmul_kernel(lhs: lhs_tiled, rhs: rhs_tiled, output: output_tiled):
         accumulator = ntl.zeros(output.shape, dtype=ntl.float32)
-        for k in range(lhs.shape[
-        accumulator
+        for k in range(lhs.shape[0]):
+            accumulator += ntl.dot(lhs[k], rhs[k])
         output = accumulator.to(ntl.float16)
 
     output = torch.empty(
--- /dev/null
+++ ninetoothed-0.4.0/tests/test_softmax.py
@@ -0,0 +1,41 @@
+import torch
+import triton
+
+import ninetoothed
+import ninetoothed.language as ntl
+from ninetoothed import Symbol, Tensor
+from tests.skippers import skip_if_cuda_not_available
+
+
+def softmax(input):
+    BLOCK_SIZE = Symbol("BLOCK_SIZE", constexpr=True)
+
+    @ninetoothed.jit
+    def softmax_kernel(
+        input_row: Tensor(2, other=float("-inf")).tile((1, BLOCK_SIZE)),
+        output_row: Tensor(2).tile((1, BLOCK_SIZE)),
+    ):
+        row_minus_max = input_row - ntl.max(input_row)
+        numerator = ntl.exp(row_minus_max)
+        denominator = ntl.sum(numerator)
+        output_row = numerator / denominator  # noqa: F841
+
+    output = torch.empty_like(input)
+
+    softmax_kernel(input, output, BLOCK_SIZE=triton.next_power_of_2(input.shape[-1]))
+
+    return output
+
+
+@skip_if_cuda_not_available
+class TestCUDA:
+    @classmethod
+    def setup_class(cls):
+        torch.manual_seed(0)
+
+        cls.input = torch.randn(1823, 781, device="cuda")
+
+    def test_fp32(self):
+        input = type(self).input.to(torch.float32)
+
+        assert torch.allclose(softmax(input), torch.softmax(input, axis=-1))
--- ninetoothed-0.2.0/src/ninetoothed/torchifier.py
+++ /dev/null
@@ -1,34 +0,0 @@
-import ast
-import re
-
-
-class Torchifier(ast.NodeTransformer):
-    def visit_Name(self, node):
-        self.generic_visit(node)
-
-        pattern = re.compile(r"([a-zA-Z_][a-zA-Z0-9_]*)_(size|stride)_(.+)")
-
-        node.id = node.id.replace("_ptr", "")
-
-        if re.fullmatch(pattern, node.id):
-            return ast.parse(
-                pattern.sub(
-                    lambda match: f"{match.group(1)}.{match.group(2)}({match.group(3)})",
-                    node.id,
-                ),
-                mode="eval",
-            ).body
-
-        return node
-
-    def visit_Attribute(self, node):
-        self.generic_visit(node)
-
-        if (
-            isinstance(node.value, ast.Name)
-            and node.value.id == "ninetoothed"
-            and node.attr == "language"
-        ):
-            return node.value
-
-        return node
The remaining files listed above (the GitHub workflow files, .gitignore, LICENSE, requirements.txt, src/ninetoothed/__init__.py, tests/__init__.py, tests/skippers.py, and tests/test_add.py) are unchanged between 0.2.0 and 0.4.0.