ninetoothed 0.11.0__tar.gz → 0.11.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ninetoothed-0.11.1/.gitattributes +1 -0
- ninetoothed-0.11.1/.github/workflows/sphinx.yml +37 -0
- {ninetoothed-0.11.0 → ninetoothed-0.11.1}/PKG-INFO +1 -1
- ninetoothed-0.11.1/docs/Makefile +20 -0
- ninetoothed-0.11.1/docs/make.bat +35 -0
- ninetoothed-0.11.1/docs/requirements.txt +2 -0
- ninetoothed-0.11.1/docs/source/_static/matmul-tiling.png +3 -0
- ninetoothed-0.11.1/docs/source/_static/ninetoothed-logo.png +3 -0
- ninetoothed-0.11.1/docs/source/_static/vecadd-tiling.png +3 -0
- ninetoothed-0.11.1/docs/source/code_generation.rst +9 -0
- ninetoothed-0.11.1/docs/source/conf.py +28 -0
- ninetoothed-0.11.1/docs/source/index.rst +14 -0
- ninetoothed-0.11.1/docs/source/installation.rst +12 -0
- ninetoothed-0.11.1/docs/source/python_api.rst +9 -0
- ninetoothed-0.11.1/docs/source/symbol.rst +4 -0
- ninetoothed-0.11.1/docs/source/tensor.rst +18 -0
- {ninetoothed-0.11.0 → ninetoothed-0.11.1}/pyproject.toml +1 -1
- {ninetoothed-0.11.0 → ninetoothed-0.11.1}/src/ninetoothed/jit.py +31 -3
- {ninetoothed-0.11.0 → ninetoothed-0.11.1}/src/ninetoothed/symbol.py +7 -0
- {ninetoothed-0.11.0 → ninetoothed-0.11.1}/src/ninetoothed/tensor.py +63 -0
- ninetoothed-0.11.1/tests/test_attention.py +92 -0
- ninetoothed-0.11.1/tests/test_conv2d.py +62 -0
- {ninetoothed-0.11.0 → ninetoothed-0.11.1}/tests/test_matmul.py +18 -12
- ninetoothed-0.11.0/docs/source/_static/matmul-tiling.png +0 -0
- ninetoothed-0.11.0/docs/source/_static/ninetoothed-logo.png +0 -0
- ninetoothed-0.11.0/docs/source/_static/vecadd-tiling.png +0 -0
- {ninetoothed-0.11.0 → ninetoothed-0.11.1}/.github/workflows/publish-to-pypi.yml +0 -0
- {ninetoothed-0.11.0 → ninetoothed-0.11.1}/.github/workflows/pytest.yml +0 -0
- {ninetoothed-0.11.0 → ninetoothed-0.11.1}/.github/workflows/ruff.yml +0 -0
- {ninetoothed-0.11.0 → ninetoothed-0.11.1}/.gitignore +0 -0
- {ninetoothed-0.11.0 → ninetoothed-0.11.1}/LICENSE +0 -0
- {ninetoothed-0.11.0 → ninetoothed-0.11.1}/README.md +0 -0
- {ninetoothed-0.11.0 → ninetoothed-0.11.1}/docs/README.zh.md +0 -0
- {ninetoothed-0.11.0 → ninetoothed-0.11.1}/requirements.txt +0 -0
- {ninetoothed-0.11.0 → ninetoothed-0.11.1}/src/ninetoothed/__init__.py +0 -0
- {ninetoothed-0.11.0 → ninetoothed-0.11.1}/src/ninetoothed/language.py +0 -0
- {ninetoothed-0.11.0 → ninetoothed-0.11.1}/src/ninetoothed/naming.py +0 -0
- {ninetoothed-0.11.0 → ninetoothed-0.11.1}/src/ninetoothed/torchifier.py +0 -0
- {ninetoothed-0.11.0 → ninetoothed-0.11.1}/tests/__init__.py +0 -0
- {ninetoothed-0.11.0 → ninetoothed-0.11.1}/tests/skippers.py +0 -0
- {ninetoothed-0.11.0 → ninetoothed-0.11.1}/tests/test_add.py +0 -0
- {ninetoothed-0.11.0 → ninetoothed-0.11.1}/tests/test_addmm.py +0 -0
- {ninetoothed-0.11.0 → ninetoothed-0.11.1}/tests/test_naming.py +0 -0
- {ninetoothed-0.11.0 → ninetoothed-0.11.1}/tests/test_softmax.py +0 -0
ninetoothed-0.11.1/.gitattributes
@@ -0,0 +1 @@
+*.png filter=lfs diff=lfs merge=lfs -text
ninetoothed-0.11.1/.github/workflows/sphinx.yml
@@ -0,0 +1,37 @@
+name: "Sphinx: Render docs"
+
+on: push
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    permissions:
+      contents: write
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          persist-credentials: false
+          lfs: true
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.10"
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install .
+          pip install -r docs/requirements.txt
+      - name: Build HTML
+        run: make -C docs html
+      - name: Upload artifacts
+        uses: actions/upload-artifact@v4
+        with:
+          name: html-docs
+          path: docs/build/html/
+      - name: Deploy
+        uses: peaceiris/actions-gh-pages@v3
+        if: github.ref == 'refs/heads/master'
+        with:
+          github_token: ${{ secrets.GITHUB_TOKEN }}
+          publish_dir: docs/build/html
+          cname: ninetoothed.org
{ninetoothed-0.11.0 → ninetoothed-0.11.1}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: ninetoothed
-Version: 0.11.0
+Version: 0.11.1
 Summary: A domain-specific language based on Triton but providing higher-level abstraction.
 Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
 Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
ninetoothed-0.11.1/docs/Makefile
@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS    ?=
+SPHINXBUILD   ?= sphinx-build
+SOURCEDIR     = source
+BUILDDIR      = build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
ninetoothed-0.11.1/docs/make.bat
@@ -0,0 +1,35 @@
+@ECHO OFF
+
+pushd %~dp0
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+	set SPHINXBUILD=sphinx-build
+)
+set SOURCEDIR=source
+set BUILDDIR=build
+
+%SPHINXBUILD% >NUL 2>NUL
+if errorlevel 9009 (
+	echo.
+	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+	echo.installed, then set the SPHINXBUILD environment variable to point
+	echo.to the full path of the 'sphinx-build' executable. Alternatively you
+	echo.may add the Sphinx directory to PATH.
+	echo.
+	echo.If you don't have Sphinx installed, grab it from
+	echo.https://www.sphinx-doc.org/
+	exit /b 1
+)
+
+if "%1" == "" goto help
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+
+:end
+popd
ninetoothed-0.11.1/docs/source/conf.py
@@ -0,0 +1,28 @@
+# Configuration file for the Sphinx documentation builder.
+#
+# For the full list of built-in configuration values, see the documentation:
+# https://www.sphinx-doc.org/en/master/usage/configuration.html
+
+# -- Project information -----------------------------------------------------
+# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
+
+project = "NineToothed"
+copyright = "2024, NineToothed Contributors"
+author = "NineToothed Contributors"
+
+# -- General configuration ---------------------------------------------------
+# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
+
+extensions = ["sphinx.ext.autodoc", "sphinx.ext.autosummary"]
+
+templates_path = ["_templates"]
+exclude_patterns = []
+
+
+# -- Options for HTML output -------------------------------------------------
+# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
+
+html_theme = "pydata_sphinx_theme"
+html_static_path = ["_static"]
+html_title = "NineToothed"
+html_logo = "_static/ninetoothed-logo.png"
ninetoothed-0.11.1/docs/source/index.rst
@@ -0,0 +1,14 @@
+NineToothed Documentation
+=========================
+
+**NineToothed** is a domain-specific language (DSL) based on Triton, offering higher-level abstractions. Through its tensor-oriented metaprogramming (TOM) model, it empowers developers to write high-performance compute kernels intuitively, without the need to manage low-level details like pointer arithmetic or memory access.
+
+.. note::
+
+   This project is under active development.
+
+.. toctree::
+   :maxdepth: 2
+
+   installation
+   python_api
ninetoothed-0.11.1/docs/source/installation.rst
@@ -0,0 +1,12 @@
+Installation
+============
+
+You can install NineToothed using ``pip``:
+
+.. code-block::
+
+   pip install ninetoothed
+
+To fully leverage its capabilities, you will also need to install a compatible deep learning framework. Currently, NineToothed supports `PyTorch <https://pytorch.org/>`_.
+
+It is generally considered good practice to use a virtual environment when installing packages with pip, though it is optional. You may find this `documentation <https://docs.python.org/3/library/venv.html>`_ helpful.
ninetoothed-0.11.1/docs/source/tensor.rst
@@ -0,0 +1,18 @@
+Tensor
+======
+
+.. autoclass:: ninetoothed.Tensor
+
+Meta-Operations
+---------------
+
+.. autosummary::
+   :toctree: generated
+   :nosignatures:
+
+   ninetoothed.Tensor.tile
+   ninetoothed.Tensor.expand
+   ninetoothed.Tensor.squeeze
+   ninetoothed.Tensor.permute
+   ninetoothed.Tensor.flatten
+   ninetoothed.Tensor.ravel
{ninetoothed-0.11.0 → ninetoothed-0.11.1}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "ninetoothed"
-version = "0.11.0"
+version = "0.11.1"
 authors = [{ name = "Jiacheng Huang", email = "huangjiacheng0709@outlook.com" }]
 description = "A domain-specific language based on Triton but providing higher-level abstraction."
 readme = "README.md"
{ninetoothed-0.11.0 → ninetoothed-0.11.1}/src/ninetoothed/jit.py
@@ -20,6 +20,13 @@ from ninetoothed.torchifier import Torchifier
 
 
 def make(arrangement, application, tensors):
+    """Integrate the arrangement and the application of the tensors.
+
+    :param arrangement: The arrangement of the tensors.
+    :param application: The application of the tensors.
+    :param tensors: The tensors.
+    :return: A handle to the compute kernel.
+    """
     params = inspect.signature(application).parameters
     types = arrangement(*tensors)
     annotations = {param: type for param, type in zip(params, types)}
@@ -28,14 +35,26 @@ def make(arrangement, application, tensors):
     return jit(application)
 
 
-def jit(
+def jit(func=None, *, _prettify=False):
+    """A decorator for generating compute kernels.
+
+    :param func: The function to be compiled.
+    :param _prettify: Whether to prettify the generated code.
+    :return: A handle to the compute kernel.
+
+    .. note::
+
+        The ``_prettify`` parameter is experimental, which might break
+        the generated code.
+    """
+
     def wrapper(func):
         return JIT(func, _prettify=_prettify)()
 
-    if
+    if func is None:
         return wrapper
 
-    return wrapper(
+    return wrapper(func)
 
 
 class JIT:
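The docstrings added above cover the two public entry points: ninetoothed.jit decorates an application directly, while ninetoothed.make combines a separate arrangement with an application, as the updated tests later in this diff do. For orientation, a minimal decorator-style sketch adapted from the project's vector-addition example; the sizes and the block-size symbol here are illustrative assumptions, not part of this release's diff:

import torch

import ninetoothed
from ninetoothed import Symbol, Tensor

BLOCK_SIZE = Symbol("BLOCK_SIZE", meta=True)


@ninetoothed.jit
def add_kernel(
    x: Tensor(1).tile((BLOCK_SIZE,)),
    y: Tensor(1).tile((BLOCK_SIZE,)),
    z: Tensor(1).tile((BLOCK_SIZE,)),
):
    # Each invocation works on one tile of x, y, and z.
    z = x + y


# Illustrative launch on CUDA tensors; the tiling above defines the grid.
a = torch.randn(8192, device="cuda")
b = torch.randn(8192, device="cuda")
c = torch.empty_like(a)
add_kernel(a, b, c)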
{ninetoothed-0.11.0 → ninetoothed-0.11.1}/src/ninetoothed/jit.py
@@ -472,6 +491,15 @@ class CodeGenerator(ast.NodeTransformer):
                 for target_dim in range(tensor.target.ndim)
                 if offsets[source_dim][target_dim] != 0
             ),
+        ) & functools.reduce(
+            lambda x, y: x & y,
+            (
+                indices[dim - tensor.innermost().target.ndim][
+                    type(self)._generate_slices(tensor, target_dim)
+                ]
+                < tensor.innermost().target.shape[dim]
+                for dim, target_dim in enumerate(tensor.innermost().target_dims)
+            ),
         )
 
         return pointers, mask
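The new term AND-folds one bounds check per dimension of the innermost target into the load/store mask, so accesses past the target's shape are masked off. A minimal NumPy sketch of the same folding pattern, with hypothetical indices and bounds rather than the generator's actual output:

import functools

import numpy as np

# Per-dimension index grids for a 4x4 block, with logical bounds (3, 2).
indices = [np.arange(4)[:, None], np.arange(4)[None, :]]
shape = (3, 2)

# AND-fold one "index < bound" check per dimension into a single mask,
# mirroring the generated functools.reduce(lambda x, y: x & y, ...) above.
mask = functools.reduce(
    lambda x, y: x & y,
    (indices[dim] < shape[dim] for dim in range(len(shape))),
)
# mask[i, j] is True only where i < 3 and j < 2.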
{ninetoothed-0.11.0 → ninetoothed-0.11.1}/src/ninetoothed/symbol.py
@@ -7,6 +7,13 @@ import ninetoothed.naming as naming
 
 
 class Symbol:
+    """A class used to represent a symbol.
+
+    :param expr: The expression used to construct the symbol.
+    :param constexpr: Whether the symbol is a constexpr.
+    :param meta: Whether the symbol is a meta.
+    """
+
     def __init__(self, expr, constexpr=None, meta=None):
         if isinstance(expr, type(self)):
             self._node = expr._node
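Both symbol flavors described by the new docstring appear in this release's tests: constexpr symbols are supplied explicitly when the kernel is launched, while meta symbols are handled as meta-parameters of the generated Triton kernel. For example:

from ninetoothed import Symbol

BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", constexpr=True)  # passed at launch time
BLOCK_SIZE_K = Symbol("BLOCK_SIZE_K", meta=True)       # a kernel meta-parameter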
{ninetoothed-0.11.0 → ninetoothed-0.11.1}/src/ninetoothed/tensor.py
@@ -8,6 +8,21 @@ from ninetoothed.symbol import Symbol
 
 
 class Tensor:
+    """A class used to represent a symbolic tensor.
+
+    :param ndim: The number of dimensions of the tensor.
+    :param shape: The shape of the tensor.
+    :param dtype: The element type of the tensor.
+    :param strides: The strides of the tensor.
+    :param other: The values for out-of-bounds positions.
+    :param constexpr_shape: Whether the sizes are constexpr.
+    :param name: The name of the tensor.
+    :param source: For internal use only.
+    :param source_dims: For internal use only.
+    :param target: For internal use only.
+    :param target_dims: For internal use only.
+    """
+
     num_instances = 0
 
     def __init__(
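The constructor parameters documented above are exercised by the new tests; two representative constructions taken from them:

from ninetoothed import Tensor

lhs = Tensor(2)                      # a symbolic 2-D tensor, given only its ndim
q = Tensor(4, constexpr_shape=True)  # a 4-D tensor whose sizes become constexprs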
{ninetoothed-0.11.0 → ninetoothed-0.11.1}/src/ninetoothed/tensor.py
@@ -70,6 +85,14 @@ class Tensor:
         type(self).num_instances += 1
 
     def tile(self, tile_shape, strides=None, dilation=None):
+        """Tiles the tensor into a hierarchical tensor.
+
+        :param tile_shape: The shape of a tile.
+        :param strides: The interval at which each tile is generated.
+        :param dilation: The spacing between tiles.
+        :return: A hierarchical tensor.
+        """
+
         if strides is None:
             strides = [-1 for _ in tile_shape]
 
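A short sketch of tile on a 2-D tensor, following the matmul arrangement later in this diff, with meta symbols as in the tests:

from ninetoothed import Symbol, Tensor

BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True)
BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True)

# Tiling yields a two-level hierarchy: an outer grid of tiles, where each
# tile is itself a symbolic tensor of shape (BLOCK_SIZE_M, BLOCK_SIZE_N).
output_tiled = Tensor(2).tile((BLOCK_SIZE_M, BLOCK_SIZE_N))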
{ninetoothed-0.11.0 → ninetoothed-0.11.1}/src/ninetoothed/tensor.py
@@ -119,6 +142,12 @@
         )
 
     def expand(self, shape):
+        """Expands the specified singleton dimensions of the tensor.
+
+        :param shape: The expanded shape.
+        :return: The expanded tensor.
+        """
+
         # TODO: Add error handling.
         return type(self)(
             shape=[
@@ -136,6 +165,12 @@
         )
 
     def squeeze(self, dim):
+        """Removes the specified singleton dimensions of the tensor.
+
+        :param dim: The dimension(s) to be squeezed.
+        :return: The squeezed tensor.
+        """
+
         if not isinstance(dim, tuple):
             dim = (dim,)
 
@@ -158,6 +193,12 @@
         )
 
     def permute(self, dims):
+        """Permutes the dimensions of the tensor.
+
+        :param dims: The permuted ordering of the dimensions.
+        :return: The permuted tensor.
+        """
+
         # TODO: Add error handling.
         new_shape = [None for _ in range(self.ndim)]
         new_strides = [None for _ in range(self.ndim)]
@@ -178,6 +219,16 @@
         )
 
     def flatten(self, start_dim=None, end_dim=None):
+        """Flattens the specified dimensions of the tensor.
+
+        See :func:`ravel` for the differences between :func:`flatten`
+        and :func:`ravel`.
+
+        :param start_dim: The first dimension to flatten.
+        :param end_dim: The dimension after the last to flatten.
+        :return: The flattened tensor.
+        """
+
         # TODO: Add error handling.
         if start_dim is None:
             start_dim = 0
@@ -222,6 +273,18 @@
         )
 
     def ravel(self):
+        """Flattens the hierarchy of the tensor.
+
+        :func:`ravel` differs from :func:`flatten`, which only flattens
+        dimensions at a single level. For example, consider a tensor
+        with two levels: the first level has a shape of ``(N, P, Q)``,
+        and the second level has a shape of ``(C, R, S)``. After
+        applying :func:`ravel`, the resulting tensor will have a single
+        flattened level with a shape of ``(N, P, Q, C, R, S)``.
+
+        :return: The raveled tensor.
+        """
+
         # TODO: Add error handling.
         new_shape = []
         new_strides = []
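The conv2d test below is the concrete use of this distinction: ravel first collapses the tile hierarchy into a single level, and only then can flatten merge dimensions within that level. A condensed sketch with hypothetical tile sizes echoing the docstring's example:

from ninetoothed import Tensor

# Tile a 3-D tensor so the outer level has shape (N, P, Q) and each tile
# has shape (C, R, S) = (4, 2, 2).
tiled = Tensor(3).tile((4, 2, 2))

raveled = tiled.ravel()  # one level: (N, P, Q, C, R, S)
flattened = raveled.flatten(end_dim=3).flatten(start_dim=1)
# -> (N * P * Q, C * R * S), the im2col-style layout used by the conv2d test.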
ninetoothed-0.11.1/tests/test_attention.py
@@ -0,0 +1,92 @@
+import torch
+import torch.nn.functional as F
+
+import ninetoothed
+import ninetoothed.language as ntl
+from ninetoothed import Symbol, Tensor
+from tests.skippers import skip_if_cuda_not_available
+
+
+def arrangement(q, k, v, o):
+    BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", constexpr=True)
+    BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", constexpr=True)
+
+    def arrange_q_or_o(input):
+        arranged = input.tile((1, 1, BLOCK_SIZE_M, -1))
+        arranged.dtype = arranged.dtype.squeeze((0, 1))
+
+        return arranged
+
+    def arrange_k_or_v(input):
+        arranged = (
+            input.tile((1, 1, BLOCK_SIZE_N, -1))
+            .tile((1, 1, -1, -1))
+            .expand((-1, -1, q_arranged.shape[-2], -1))
+        )
+        arranged.dtype = arranged.dtype.squeeze((0, 1, 3))
+        arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1))
+
+        return arranged
+
+    q_arranged = arrange_q_or_o(q)
+
+    return q_arranged, arrange_k_or_v(k), arrange_k_or_v(v), arrange_q_or_o(o)
+
+
+def application(q, k, v, o):
+    q_loaded = (q * 1.44269504089).to(ntl.float16)
+
+    acc = ntl.zeros((q.shape[-2], q.shape[-1]), dtype=ntl.float32)
+    l_i = ntl.full((q.shape[-2],), 1, dtype=ntl.float32)
+    m_i = ntl.full((q.shape[-2],), float("-inf"), dtype=ntl.float32)
+
+    for i in range(k.shape[0]):
+        qk = ntl.dot(q_loaded, ntl.trans(k[i]))
+
+        m_ij = ntl.maximum(m_i, ntl.max(qk, 1))
+        p = ntl.exp2(qk - m_ij[:, None])
+        l_ij = ntl.sum(p, 1)
+
+        alpha = ntl.exp2(m_i - m_ij)
+        acc = acc * alpha[:, None] + ntl.dot(p.to(ntl.float16), v[i])
+        m_i = m_ij
+        l_i = l_i * alpha + l_ij
+
+    acc /= l_i[:, None]
+    o = acc  # noqa: F841
+
+
+def attention(q, k, v):
+    o = torch.empty_like(q, dtype=v.dtype)
+
+    attention_kernel = ninetoothed.make(
+        arrangement, application, (Tensor(4, constexpr_shape=True) for _ in range(4))
+    )
+
+    attention_kernel(q, k, v, o, BLOCK_SIZE_M=128, BLOCK_SIZE_N=64)
+
+    return o
+
+
+@skip_if_cuda_not_available
+class TestCUDA:
+    @classmethod
+    def setup_class(cls):
+        torch.manual_seed(0)
+
+        shape = (2, 4, 1024, 64)
+
+        cls.q = torch.randn(shape, device="cuda")
+        cls.k = torch.randn(shape, device="cuda")
+        cls.v = torch.randn(shape, device="cuda")
+
+    def test_fp16(self):
+        q = type(self).q.to(torch.float16)
+        k = type(self).k.to(torch.float16)
+        v = type(self).v.to(torch.float16)
+
+        assert torch.allclose(
+            attention(q, k, v),
+            F.scaled_dot_product_attention(q, k, v, scale=1),
+            atol=0.01,
+        )
ninetoothed-0.11.1/tests/test_conv2d.py
@@ -0,0 +1,62 @@
+import torch
+import torch.nn.functional as F
+
+import ninetoothed
+import tests.test_matmul as matmul
+from ninetoothed import Tensor
+from tests.skippers import skip_if_cuda_not_available
+
+
+def arrangement(input, filter, output):
+    input_tiled = input.tile((1, *filter.shape[1:]), strides=(-1, -1, 1, 1))
+    input_squeezed = input_tiled.squeeze(1)
+    input_squeezed.dtype = input_squeezed.dtype.squeeze(0)
+    input_raveled = input_squeezed.ravel()
+    input_flattened = input_raveled.flatten(end_dim=3).flatten(start_dim=1)
+
+    filter_flattened = filter.flatten(start_dim=1)
+    filter_permuted = filter_flattened.permute((1, 0))
+
+    output_flattened = output.permute((0, 2, 3, 1)).flatten(end_dim=3)
+
+    return matmul.arrangement(input_flattened, filter_permuted, output_flattened)
+
+
+def conv2d(input, filter):
+    n, _, h, w = input.shape
+    k, _, r, s = filter.shape
+    p = h - r + 1
+    q = w - s + 1
+
+    output = torch.empty((n, k, p, q), device=input.device, dtype=input.dtype)
+
+    conv2d_kernel = ninetoothed.make(
+        arrangement,
+        matmul.application,
+        (Tensor(4), Tensor(4, constexpr_shape=True), Tensor(4)),
+    )
+
+    conv2d_kernel(input, filter, output)
+
+    return output
+
+
+@skip_if_cuda_not_available
+class TestCUDA:
+    @classmethod
+    def setup_class(cls):
+        torch.manual_seed(0)
+
+        n, c, h, w = 4, 64, 16, 16
+        k, _, r, s = 512, c, 3, 3
+
+        cls.input = torch.randn(n, c, h, w, device="cuda")
+        cls.filter = torch.randn(k, c, r, s, device="cuda")
+
+    def test_fp16(self):
+        input = type(self).input.to(torch.float16)
+        filter = type(self).filter.to(torch.float16)
+
+        assert torch.allclose(
+            conv2d(input, filter), F.conv2d(input, filter), atol=0.001, rtol=0.001
+        )
{ninetoothed-0.11.0 → ninetoothed-0.11.1}/tests/test_matmul.py
@@ -6,40 +6,46 @@ from ninetoothed import Symbol, Tensor
 from tests.skippers import skip_if_cuda_not_available, skip_if_float8_e5m2_not_supported
 
 
-def
+def arrangement(lhs, rhs, output):
     BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True)
     BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True)
     BLOCK_SIZE_K = Symbol("BLOCK_SIZE_K", meta=True)
 
-    output_tiled =
+    output_tiled = output.tile((BLOCK_SIZE_M, BLOCK_SIZE_N))
 
     lhs_tiled = (
-
-        .tile((BLOCK_SIZE_M, BLOCK_SIZE_K))
+        lhs.tile((BLOCK_SIZE_M, BLOCK_SIZE_K))
         .tile((1, -1))
         .expand((-1, output_tiled.shape[1]))
     )
     lhs_tiled.dtype = lhs_tiled.dtype.squeeze(0)
 
     rhs_tiled = (
-
-        .tile((BLOCK_SIZE_K, BLOCK_SIZE_N))
+        rhs.tile((BLOCK_SIZE_K, BLOCK_SIZE_N))
         .tile((-1, 1))
         .expand((output_tiled.shape[0], -1))
     )
     rhs_tiled.dtype = rhs_tiled.dtype.squeeze(1)
 
-
-
-
-
-
-
+    return lhs_tiled, rhs_tiled, output_tiled
+
+
+def application(lhs, rhs, output):
+    accumulator = ntl.zeros(output.shape, dtype=ntl.float32)
+    for k in range(lhs.shape[0]):
+        accumulator += ntl.dot(lhs[k], rhs[k])
+    output = accumulator.to(ntl.float16)
+
 
+def matmul(lhs, rhs):
     output = torch.empty(
         (lhs.shape[0], rhs.shape[1]), device=lhs.device, dtype=torch.float16
     )
 
+    matmul_kernel = ninetoothed.make(
+        arrangement, application, (Tensor(2), Tensor(2), Tensor(2))
+    )
+
     matmul_kernel(lhs, rhs, output)
 
     return output
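For context, a hedged usage sketch of the refactored helper; the sizes are hypothetical, and the real test compares the result against torch within a tolerance:

import torch

lhs = torch.randn(512, 256, device="cuda", dtype=torch.float16)
rhs = torch.randn(256, 128, device="cuda", dtype=torch.float16)

out = matmul(lhs, rhs)  # the helper defined in the hunk above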
Binary file (ninetoothed-0.11.0/docs/source/_static/matmul-tiling.png)
Binary file (ninetoothed-0.11.0/docs/source/_static/ninetoothed-logo.png)
Binary file (ninetoothed-0.11.0/docs/source/_static/vecadd-tiling.png)