ninetoothed 0.10.0__tar.gz → 0.11.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. ninetoothed-0.11.1/.gitattributes +1 -0
  2. ninetoothed-0.11.1/.github/workflows/publish-to-pypi.yml +92 -0
  3. ninetoothed-0.11.1/.github/workflows/sphinx.yml +37 -0
  4. {ninetoothed-0.10.0 → ninetoothed-0.11.1}/PKG-INFO +1 -1
  5. ninetoothed-0.11.1/docs/Makefile +20 -0
  6. ninetoothed-0.11.1/docs/make.bat +35 -0
  7. ninetoothed-0.11.1/docs/requirements.txt +2 -0
  8. ninetoothed-0.11.1/docs/source/_static/matmul-tiling.png +3 -0
  9. ninetoothed-0.11.1/docs/source/_static/ninetoothed-logo.png +3 -0
  10. ninetoothed-0.11.1/docs/source/_static/vecadd-tiling.png +3 -0
  11. ninetoothed-0.11.1/docs/source/code_generation.rst +9 -0
  12. ninetoothed-0.11.1/docs/source/conf.py +28 -0
  13. ninetoothed-0.11.1/docs/source/index.rst +14 -0
  14. ninetoothed-0.11.1/docs/source/installation.rst +12 -0
  15. ninetoothed-0.11.1/docs/source/python_api.rst +9 -0
  16. ninetoothed-0.11.1/docs/source/symbol.rst +4 -0
  17. ninetoothed-0.11.1/docs/source/tensor.rst +18 -0
  18. {ninetoothed-0.10.0 → ninetoothed-0.11.1}/pyproject.toml +1 -1
  19. {ninetoothed-0.10.0 → ninetoothed-0.11.1}/src/ninetoothed/jit.py +31 -3
  20. {ninetoothed-0.10.0 → ninetoothed-0.11.1}/src/ninetoothed/symbol.py +7 -0
  21. {ninetoothed-0.10.0 → ninetoothed-0.11.1}/src/ninetoothed/tensor.py +79 -0
  22. ninetoothed-0.11.1/tests/test_attention.py +92 -0
  23. ninetoothed-0.11.1/tests/test_conv2d.py +62 -0
  24. {ninetoothed-0.10.0 → ninetoothed-0.11.1}/tests/test_matmul.py +18 -12
  25. ninetoothed-0.10.0/docs/source/_static/matmul-tiling.png +0 -0
  26. ninetoothed-0.10.0/docs/source/_static/ninetoothed-logo.png +0 -0
  27. ninetoothed-0.10.0/docs/source/_static/vecadd-tiling.png +0 -0
  28. {ninetoothed-0.10.0 → ninetoothed-0.11.1}/.github/workflows/pytest.yml +0 -0
  29. {ninetoothed-0.10.0 → ninetoothed-0.11.1}/.github/workflows/ruff.yml +0 -0
  30. {ninetoothed-0.10.0 → ninetoothed-0.11.1}/.gitignore +0 -0
  31. {ninetoothed-0.10.0 → ninetoothed-0.11.1}/LICENSE +0 -0
  32. {ninetoothed-0.10.0 → ninetoothed-0.11.1}/README.md +0 -0
  33. {ninetoothed-0.10.0 → ninetoothed-0.11.1}/docs/README.zh.md +0 -0
  34. {ninetoothed-0.10.0 → ninetoothed-0.11.1}/requirements.txt +0 -0
  35. {ninetoothed-0.10.0 → ninetoothed-0.11.1}/src/ninetoothed/__init__.py +0 -0
  36. {ninetoothed-0.10.0 → ninetoothed-0.11.1}/src/ninetoothed/language.py +0 -0
  37. {ninetoothed-0.10.0 → ninetoothed-0.11.1}/src/ninetoothed/naming.py +0 -0
  38. {ninetoothed-0.10.0 → ninetoothed-0.11.1}/src/ninetoothed/torchifier.py +0 -0
  39. {ninetoothed-0.10.0 → ninetoothed-0.11.1}/tests/__init__.py +0 -0
  40. {ninetoothed-0.10.0 → ninetoothed-0.11.1}/tests/skippers.py +0 -0
  41. {ninetoothed-0.10.0 → ninetoothed-0.11.1}/tests/test_add.py +0 -0
  42. {ninetoothed-0.10.0 → ninetoothed-0.11.1}/tests/test_addmm.py +0 -0
  43. {ninetoothed-0.10.0 → ninetoothed-0.11.1}/tests/test_naming.py +0 -0
  44. {ninetoothed-0.10.0 → ninetoothed-0.11.1}/tests/test_softmax.py +0 -0
.gitattributes
@@ -0,0 +1 @@
+ *.png filter=lfs diff=lfs merge=lfs -text
.github/workflows/publish-to-pypi.yml
@@ -0,0 +1,92 @@
+ name: Publish Python 🐍 distribution 📦 to PyPI
+
+ on: push
+
+ jobs:
+   build:
+     name: Build distribution 📦
+     runs-on: ubuntu-latest
+
+     steps:
+       - uses: actions/checkout@v4
+         with:
+           persist-credentials: false
+       - name: Set up Python
+         uses: actions/setup-python@v5
+         with:
+           python-version: "3.x"
+       - name: Install pypa/build
+         run: >-
+           python3 -m
+           pip install
+           build
+           --user
+       - name: Build a binary wheel and a source tarball
+         run: python3 -m build
+       - name: Store the distribution packages
+         uses: actions/upload-artifact@v4
+         with:
+           name: python-package-distributions
+           path: dist/
+
+   publish-to-pypi:
+     name: >-
+       Publish Python 🐍 distribution 📦 to PyPI
+     if: startsWith(github.ref, 'refs/tags/')
+     needs:
+       - build
+     runs-on: ubuntu-latest
+     environment:
+       name: pypi
+       url: https://pypi.org/p/ninetoothed
+     permissions:
+       id-token: write
+
+     steps:
+       - name: Download all the dists
+         uses: actions/download-artifact@v4
+         with:
+           name: python-package-distributions
+           path: dist/
+       - name: Publish distribution 📦 to PyPI
+         uses: pypa/gh-action-pypi-publish@release/v1
+
+   github-release:
+     name: >-
+       Sign the Python 🐍 distribution 📦 with Sigstore
+       and upload them to GitHub Release
+     needs:
+       - publish-to-pypi
+     runs-on: ubuntu-latest
+
+     permissions:
+       contents: write
+       id-token: write
+
+     steps:
+       - name: Download all the dists
+         uses: actions/download-artifact@v4
+         with:
+           name: python-package-distributions
+           path: dist/
+       - name: Sign the dists with Sigstore
+         uses: sigstore/gh-action-sigstore-python@v3.0.0
+         with:
+           inputs: >-
+             ./dist/*.tar.gz
+             ./dist/*.whl
+       - name: Create GitHub Release
+         env:
+           GITHUB_TOKEN: ${{ github.token }}
+         run: >-
+           gh release create
+           "$GITHUB_REF_NAME"
+           --repo "$GITHUB_REPOSITORY"
+           --notes ""
+       - name: Upload artifact signatures to GitHub Release
+         env:
+           GITHUB_TOKEN: ${{ github.token }}
+         run: >-
+           gh release upload
+           "$GITHUB_REF_NAME" dist/**
+           --repo "$GITHUB_REPOSITORY"
.github/workflows/sphinx.yml
@@ -0,0 +1,37 @@
+ name: "Sphinx: Render docs"
+
+ on: push
+
+ jobs:
+   build:
+     runs-on: ubuntu-latest
+     permissions:
+       contents: write
+     steps:
+       - uses: actions/checkout@v4
+         with:
+           persist-credentials: false
+           lfs: true
+       - name: Set up Python
+         uses: actions/setup-python@v5
+         with:
+           python-version: "3.10"
+       - name: Install dependencies
+         run: |
+           python -m pip install --upgrade pip
+           pip install .
+           pip install -r docs/requirements.txt
+       - name: Build HTML
+         run: make -C docs html
+       - name: Upload artifacts
+         uses: actions/upload-artifact@v4
+         with:
+           name: html-docs
+           path: docs/build/html/
+       - name: Deploy
+         uses: peaceiris/actions-gh-pages@v3
+         if: github.ref == 'refs/heads/master'
+         with:
+           github_token: ${{ secrets.GITHUB_TOKEN }}
+           publish_dir: docs/build/html
+           cname: ninetoothed.org
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: ninetoothed
- Version: 0.10.0
+ Version: 0.11.1
  Summary: A domain-specific language based on Triton but providing higher-level abstraction.
  Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
  Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
docs/Makefile
@@ -0,0 +1,20 @@
+ # Minimal makefile for Sphinx documentation
+ #
+
+ # You can set these variables from the command line, and also
+ # from the environment for the first two.
+ SPHINXOPTS    ?=
+ SPHINXBUILD   ?= sphinx-build
+ SOURCEDIR     = source
+ BUILDDIR      = build
+
+ # Put it first so that "make" without argument is like "make help".
+ help:
+ 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+ .PHONY: help Makefile
+
+ # Catch-all target: route all unknown targets to Sphinx using the new
+ # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+ %: Makefile
+ 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
docs/make.bat
@@ -0,0 +1,35 @@
+ @ECHO OFF
+
+ pushd %~dp0
+
+ REM Command file for Sphinx documentation
+
+ if "%SPHINXBUILD%" == "" (
+ 	set SPHINXBUILD=sphinx-build
+ )
+ set SOURCEDIR=source
+ set BUILDDIR=build
+
+ %SPHINXBUILD% >NUL 2>NUL
+ if errorlevel 9009 (
+ 	echo.
+ 	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+ 	echo.installed, then set the SPHINXBUILD environment variable to point
+ 	echo.to the full path of the 'sphinx-build' executable. Alternatively you
+ 	echo.may add the Sphinx directory to PATH.
+ 	echo.
+ 	echo.If you don't have Sphinx installed, grab it from
+ 	echo.https://www.sphinx-doc.org/
+ 	exit /b 1
+ )
+
+ if "%1" == "" goto help
+
+ %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+ goto end
+
+ :help
+ %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+
+ :end
+ popd
docs/requirements.txt
@@ -0,0 +1,2 @@
+ sphinx
+ pydata-sphinx-theme
docs/source/_static/matmul-tiling.png
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f4436ac3d9a7b1362efdbd33cd14ff9be34ceef89d596d4554a4727dd0ef57d5
+ size 224240
docs/source/_static/ninetoothed-logo.png
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2eb81c9d72a5f53d3409a8f56f5bf8ea85145ac788930743a5449692097d4537
+ size 17335
docs/source/_static/vecadd-tiling.png
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f2d6725fb932a4d96d1c1e7a4c730e2685d00c94d8677f63512d98182de638a1
+ size 44416
docs/source/code_generation.rst
@@ -0,0 +1,9 @@
+ Code Generation
+ ===============
+
+ .. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    ninetoothed.jit
+    ninetoothed.make
docs/source/conf.py
@@ -0,0 +1,28 @@
+ # Configuration file for the Sphinx documentation builder.
+ #
+ # For the full list of built-in configuration values, see the documentation:
+ # https://www.sphinx-doc.org/en/master/usage/configuration.html
+
+ # -- Project information -----------------------------------------------------
+ # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
+
+ project = "NineToothed"
+ copyright = "2024, NineToothed Contributors"
+ author = "NineToothed Contributors"
+
+ # -- General configuration ---------------------------------------------------
+ # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
+
+ extensions = ["sphinx.ext.autodoc", "sphinx.ext.autosummary"]
+
+ templates_path = ["_templates"]
+ exclude_patterns = []
+
+
+ # -- Options for HTML output -------------------------------------------------
+ # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
+
+ html_theme = "pydata_sphinx_theme"
+ html_static_path = ["_static"]
+ html_title = "NineToothed"
+ html_logo = "_static/ninetoothed-logo.png"
docs/source/index.rst
@@ -0,0 +1,14 @@
+ NineToothed Documentation
+ =========================
+
+ **NineToothed** is a domain-specific language (DSL) based on Triton, offering higher-level abstractions. Through its tensor-oriented metaprogramming (TOM) model, it empowers developers to write high-performance compute kernels intuitively, without the need to manage low-level details like pointer arithmetic or memory access.
+
+ .. note::
+
+    This project is under active development.
+
+ .. toctree::
+    :maxdepth: 2
+
+    installation
+    python_api
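
To make the TOM model described above concrete: a minimal vector-addition sketch in the style of the package's tests, assuming the ``Tensor``/``Symbol`` API that appears later in this diff (the tile shape and names are illustrative):

    import ninetoothed
    from ninetoothed import Symbol, Tensor

    BLOCK_SIZE = Symbol("BLOCK_SIZE", meta=True)  # tile size chosen by the autotuner

    @ninetoothed.jit
    def add_kernel(
        x: Tensor(1).tile((BLOCK_SIZE,)),  # each program instance sees one tile of x
        y: Tensor(1).tile((BLOCK_SIZE,)),
        z: Tensor(1).tile((BLOCK_SIZE,)),
    ):
        z = x + y  # noqa: F841  # assigning to a parameter stores the result tile

No pointer arithmetic or masking appears in the kernel body; the tiling annotations carry that information.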
docs/source/installation.rst
@@ -0,0 +1,12 @@
+ Installation
+ ============
+
+ You can install NineToothed using ``pip``:
+
+ .. code-block::
+
+    pip install ninetoothed
+
+ To fully leverage its capabilities, you will also need to install a compatible deep learning framework. Currently, NineToothed supports `PyTorch <https://pytorch.org/>`_.
+
+ Using a virtual environment when installing packages with ``pip`` is good practice, though optional. You may find the `venv documentation <https://docs.python.org/3/library/venv.html>`_ helpful.
docs/source/python_api.rst
@@ -0,0 +1,9 @@
+ Python API
+ ==========
+
+ .. toctree::
+    :maxdepth: 1
+
+    code_generation
+    tensor
+    symbol
docs/source/symbol.rst
@@ -0,0 +1,4 @@
+ Symbol
+ ======
+
+ .. autoclass:: ninetoothed.Symbol
docs/source/tensor.rst
@@ -0,0 +1,18 @@
+ Tensor
+ ======
+
+ .. autoclass:: ninetoothed.Tensor
+
+ Meta-Operations
+ ---------------
+
+ .. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    ninetoothed.Tensor.tile
+    ninetoothed.Tensor.expand
+    ninetoothed.Tensor.squeeze
+    ninetoothed.Tensor.permute
+    ninetoothed.Tensor.flatten
+    ninetoothed.Tensor.ravel
pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
  [project]
  name = "ninetoothed"
- version = "0.10.0"
+ version = "0.11.1"
  authors = [{ name = "Jiacheng Huang", email = "huangjiacheng0709@outlook.com" }]
  description = "A domain-specific language based on Triton but providing higher-level abstraction."
  readme = "README.md"
src/ninetoothed/jit.py
@@ -20,6 +20,13 @@ from ninetoothed.torchifier import Torchifier
 
 
  def make(arrangement, application, tensors):
+     """Integrate the arrangement and the application of the tensors.
+
+     :param arrangement: The arrangement of the tensors.
+     :param application: The application of the tensors.
+     :param tensors: The tensors.
+     :return: A handle to the compute kernel.
+     """
      params = inspect.signature(application).parameters
      types = arrangement(*tensors)
      annotations = {param: type for param, type in zip(params, types)}
@@ -28,14 +35,26 @@ def make(arrangement, application, tensors):
      return jit(application)
 
 
- def jit(_func=None, *, _prettify=False):
+ def jit(func=None, *, _prettify=False):
+     """A decorator for generating compute kernels.
+
+     :param func: The function to be compiled.
+     :param _prettify: Whether to prettify the generated code.
+     :return: A handle to the compute kernel.
+
+     .. note::
+
+         The ``_prettify`` parameter is experimental and might break
+         the generated code.
+     """
+
      def wrapper(func):
          return JIT(func, _prettify=_prettify)()
 
-     if _func is None:
+     if func is None:
          return wrapper
 
-     return wrapper(_func)
+     return wrapper(func)
 
 
  class JIT:
@@ -472,6 +491,15 @@ class CodeGenerator(ast.NodeTransformer):
                  for target_dim in range(tensor.target.ndim)
                  if offsets[source_dim][target_dim] != 0
              ),
+         ) & functools.reduce(
+             lambda x, y: x & y,
+             (
+                 indices[dim - tensor.innermost().target.ndim][
+                     type(self)._generate_slices(tensor, target_dim)
+                 ]
+                 < tensor.innermost().target.shape[dim]
+                 for dim, target_dim in enumerate(tensor.innermost().target_dims)
+             ),
          )
 
          return pointers, mask
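
The new ``make`` docstring above describes the arrangement/application split; a minimal usage sketch (names assumed, mirroring ``tests/test_matmul.py`` later in this diff):

    import ninetoothed
    from ninetoothed import Tensor

    # Build the kernel once: the arrangement tiles the symbolic tensors,
    # and the application is written against those tiled views.
    matmul_kernel = ninetoothed.make(
        arrangement, application, (Tensor(2), Tensor(2), Tensor(2))
    )

    matmul_kernel(lhs, rhs, output)  # launch on concrete PyTorch tensors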
src/ninetoothed/symbol.py
@@ -7,6 +7,13 @@ import ninetoothed.naming as naming
 
 
  class Symbol:
+     """A class used to represent a symbol.
+
+     :param expr: The expression used to construct the symbol.
+     :param constexpr: Whether the symbol is a constexpr.
+     :param meta: Whether the symbol is a meta-parameter.
+     """
+
      def __init__(self, expr, constexpr=None, meta=None):
          if isinstance(expr, type(self)):
              self._node = expr._node
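
A short sketch of the two ``Symbol`` flags, as the tests in this diff use them:

    from ninetoothed import Symbol

    BLOCK_SIZE_K = Symbol("BLOCK_SIZE_K", meta=True)       # meta: value picked by the autotuner
    BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", constexpr=True)  # constexpr: fixed at kernel launch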
src/ninetoothed/tensor.py
@@ -8,6 +8,21 @@ from ninetoothed.symbol import Symbol
 
 
  class Tensor:
+     """A class used to represent a symbolic tensor.
+
+     :param ndim: The number of dimensions of the tensor.
+     :param shape: The shape of the tensor.
+     :param dtype: The element type of the tensor.
+     :param strides: The strides of the tensor.
+     :param other: The values for out-of-bounds positions.
+     :param constexpr_shape: Whether the sizes are constexpr.
+     :param name: The name of the tensor.
+     :param source: For internal use only.
+     :param source_dims: For internal use only.
+     :param target: For internal use only.
+     :param target_dims: For internal use only.
+     """
+
      num_instances = 0
 
      def __init__(
@@ -70,6 +85,14 @@ class Tensor:
          type(self).num_instances += 1
 
      def tile(self, tile_shape, strides=None, dilation=None):
+         """Tiles the tensor into a hierarchical tensor.
+
+         :param tile_shape: The shape of a tile.
+         :param strides: The interval at which each tile is generated.
+         :param dilation: The spacing between tiles.
+         :return: A hierarchical tensor.
+         """
+
          if strides is None:
              strides = [-1 for _ in tile_shape]
 
@@ -119,6 +142,12 @@ class Tensor:
          )
 
      def expand(self, shape):
+         """Expands the specified singleton dimensions of the tensor.
+
+         :param shape: The expanded shape.
+         :return: The expanded tensor.
+         """
+
          # TODO: Add error handling.
          return type(self)(
              shape=[
@@ -132,9 +161,16 @@ class Tensor:
              ],
              source=self.source,
              source_dims=self.source_dims,
+             target_dims=self.target_dims,
          )
 
      def squeeze(self, dim):
+         """Removes the specified singleton dimensions of the tensor.
+
+         :param dim: The dimension(s) to be squeezed.
+         :return: The squeezed tensor.
+         """
+
          if not isinstance(dim, tuple):
              dim = (dim,)
 
@@ -149,9 +185,20 @@ class Tensor:
                  for i, source_dim in enumerate(self.source_dims)
                  if i not in dim
              ],
+             target_dims=[
+                 target_dim
+                 for i, target_dim in enumerate(self.target_dims)
+                 if i not in dim
+             ],
          )
 
      def permute(self, dims):
+         """Permutes the dimensions of the tensor.
+
+         :param dims: The permuted ordering of the dimensions.
+         :return: The permuted tensor.
+         """
+
          # TODO: Add error handling.
          new_shape = [None for _ in range(self.ndim)]
          new_strides = [None for _ in range(self.ndim)]
@@ -168,9 +215,20 @@ class Tensor:
              strides=new_strides,
              source=self.source,
              source_dims=new_source_dims,
+             target_dims=self.target_dims,
          )
 
      def flatten(self, start_dim=None, end_dim=None):
+         """Flattens the specified dimensions of the tensor.
+
+         See :func:`ravel` for the differences between :func:`flatten`
+         and :func:`ravel`.
+
+         :param start_dim: The first dimension to flatten.
+         :param end_dim: The dimension after the last to flatten.
+         :return: The flattened tensor.
+         """
+
          # TODO: Add error handling.
          if start_dim is None:
              start_dim = 0
@@ -197,15 +255,36 @@ class Tensor:
              leading_source_dims + (flattening_source_dims,) + trailing_source_dims
          )
 
+         leading_target_dims = self.target_dims[:start_dim]
+         flattening_target_dims = self.target_dims[start_dim:end_dim]
+         trailing_target_dims = self.target_dims[end_dim:]
+
+         new_target_dims = (
+             leading_target_dims + (flattening_target_dims[-1],) + trailing_target_dims
+         )
+
          return type(self)(
              shape=new_shape,
              dtype=self.dtype,
              strides=new_strides,
              source=self.source,
              source_dims=new_source_dims,
+             target_dims=new_target_dims,
          )
 
      def ravel(self):
+         """Flattens the hierarchy of the tensor.
+
+         :func:`ravel` differs from :func:`flatten`, which only flattens
+         dimensions at a single level. For example, consider a tensor
+         with two levels: the first level has a shape of ``(N, P, Q)``,
+         and the second level has a shape of ``(C, R, S)``. After
+         applying :func:`ravel`, the resulting tensor will have a single
+         flattened level with a shape of ``(N, P, Q, C, R, S)``.
+
+         :return: The raveled tensor.
+         """
+
          # TODO: Add error handling.
          new_shape = []
          new_strides = []
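
The documented meta-operations compose into layout pipelines. A minimal sketch (shapes assumed) in the spirit of ``tests/test_conv2d.py`` below:

    from ninetoothed import Tensor

    input = Tensor(4)  # a symbolic (N, C, H, W) tensor
    # Tile into sliding 3x3 windows: -1 keeps the default interval for a
    # dimension, 1 slides the window one element at a time along H and W.
    windows = input.tile((1, 1, 3, 3), strides=(-1, -1, 1, 1))
    # ravel collapses the two-level hierarchy into a single level whose
    # dimensions are the grid dimensions followed by the window dimensions.
    patches = windows.ravel()
    # flatten merges dimensions at one level, here each window into a row.
    rows = patches.flatten(start_dim=4)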
tests/test_attention.py
@@ -0,0 +1,92 @@
+ import torch
+ import torch.nn.functional as F
+
+ import ninetoothed
+ import ninetoothed.language as ntl
+ from ninetoothed import Symbol, Tensor
+ from tests.skippers import skip_if_cuda_not_available
+
+
+ def arrangement(q, k, v, o):
+     BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", constexpr=True)
+     BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", constexpr=True)
+
+     def arrange_q_or_o(input):
+         arranged = input.tile((1, 1, BLOCK_SIZE_M, -1))
+         arranged.dtype = arranged.dtype.squeeze((0, 1))
+
+         return arranged
+
+     def arrange_k_or_v(input):
+         arranged = (
+             input.tile((1, 1, BLOCK_SIZE_N, -1))
+             .tile((1, 1, -1, -1))
+             .expand((-1, -1, q_arranged.shape[-2], -1))
+         )
+         arranged.dtype = arranged.dtype.squeeze((0, 1, 3))
+         arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1))
+
+         return arranged
+
+     q_arranged = arrange_q_or_o(q)
+
+     return q_arranged, arrange_k_or_v(k), arrange_k_or_v(v), arrange_q_or_o(o)
+
+
+ def application(q, k, v, o):
+     q_loaded = (q * 1.44269504089).to(ntl.float16)
+
+     acc = ntl.zeros((q.shape[-2], q.shape[-1]), dtype=ntl.float32)
+     l_i = ntl.full((q.shape[-2],), 1, dtype=ntl.float32)
+     m_i = ntl.full((q.shape[-2],), float("-inf"), dtype=ntl.float32)
+
+     for i in range(k.shape[0]):
+         qk = ntl.dot(q_loaded, ntl.trans(k[i]))
+
+         m_ij = ntl.maximum(m_i, ntl.max(qk, 1))
+         p = ntl.exp2(qk - m_ij[:, None])
+         l_ij = ntl.sum(p, 1)
+
+         alpha = ntl.exp2(m_i - m_ij)
+         acc = acc * alpha[:, None] + ntl.dot(p.to(ntl.float16), v[i])
+         m_i = m_ij
+         l_i = l_i * alpha + l_ij
+
+     acc /= l_i[:, None]
+     o = acc  # noqa: F841
+
+
+ def attention(q, k, v):
+     o = torch.empty_like(q, dtype=v.dtype)
+
+     attention_kernel = ninetoothed.make(
+         arrangement, application, (Tensor(4, constexpr_shape=True) for _ in range(4))
+     )
+
+     attention_kernel(q, k, v, o, BLOCK_SIZE_M=128, BLOCK_SIZE_N=64)
+
+     return o
+
+
+ @skip_if_cuda_not_available
+ class TestCUDA:
+     @classmethod
+     def setup_class(cls):
+         torch.manual_seed(0)
+
+         shape = (2, 4, 1024, 64)
+
+         cls.q = torch.randn(shape, device="cuda")
+         cls.k = torch.randn(shape, device="cuda")
+         cls.v = torch.randn(shape, device="cuda")
+
+     def test_fp16(self):
+         q = type(self).q.to(torch.float16)
+         k = type(self).k.to(torch.float16)
+         v = type(self).v.to(torch.float16)
+
+         assert torch.allclose(
+             attention(q, k, v),
+             F.scaled_dot_product_attention(q, k, v, scale=1),
+             atol=0.01,
+         )
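
One detail of ``application`` above: ``q`` is pre-scaled by log2(e) ≈ 1.44269504089 so that the online softmax can use ``exp2`` instead of ``exp``. A standalone check of that identity:

    import math

    import torch

    # exp(x) == exp2(x * log2(e)); the kernel folds log2(e) into q up front.
    x = torch.randn(8)
    assert torch.allclose(torch.exp(x), torch.exp2(x * math.log2(math.e)))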
tests/test_conv2d.py
@@ -0,0 +1,62 @@
+ import torch
+ import torch.nn.functional as F
+
+ import ninetoothed
+ import tests.test_matmul as matmul
+ from ninetoothed import Tensor
+ from tests.skippers import skip_if_cuda_not_available
+
+
+ def arrangement(input, filter, output):
+     input_tiled = input.tile((1, *filter.shape[1:]), strides=(-1, -1, 1, 1))
+     input_squeezed = input_tiled.squeeze(1)
+     input_squeezed.dtype = input_squeezed.dtype.squeeze(0)
+     input_raveled = input_squeezed.ravel()
+     input_flattened = input_raveled.flatten(end_dim=3).flatten(start_dim=1)
+
+     filter_flattened = filter.flatten(start_dim=1)
+     filter_permuted = filter_flattened.permute((1, 0))
+
+     output_flattened = output.permute((0, 2, 3, 1)).flatten(end_dim=3)
+
+     return matmul.arrangement(input_flattened, filter_permuted, output_flattened)
+
+
+ def conv2d(input, filter):
+     n, _, h, w = input.shape
+     k, _, r, s = filter.shape
+     p = h - r + 1
+     q = w - s + 1
+
+     output = torch.empty((n, k, p, q), device=input.device, dtype=input.dtype)
+
+     conv2d_kernel = ninetoothed.make(
+         arrangement,
+         matmul.application,
+         (Tensor(4), Tensor(4, constexpr_shape=True), Tensor(4)),
+     )
+
+     conv2d_kernel(input, filter, output)
+
+     return output
+
+
+ @skip_if_cuda_not_available
+ class TestCUDA:
+     @classmethod
+     def setup_class(cls):
+         torch.manual_seed(0)
+
+         n, c, h, w = 4, 64, 16, 16
+         k, _, r, s = 512, c, 3, 3
+
+         cls.input = torch.randn(n, c, h, w, device="cuda")
+         cls.filter = torch.randn(k, c, r, s, device="cuda")
+
+     def test_fp16(self):
+         input = type(self).input.to(torch.float16)
+         filter = type(self).filter.to(torch.float16)
+
+         assert torch.allclose(
+             conv2d(input, filter), F.conv2d(input, filter), atol=0.001, rtol=0.001
+         )
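
The arrangement above lowers conv2d onto the matmul kernel via an implicit im2col. A standalone PyTorch sketch of the same reduction, using the test's shapes (``F.unfold`` plays the role that the symbolic tiling plays in the kernel):

    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 64, 16, 16)  # (N, C, H, W), as in setup_class
    w = torch.randn(512, 64, 3, 3)  # (K, C, R, S)

    # im2col: one row per output position, one column per (C, R, S) element.
    rows = F.unfold(x, kernel_size=3).transpose(1, 2).reshape(-1, 64 * 3 * 3)
    cols = w.reshape(512, -1).t()

    # (N*P*Q, C*R*S) @ (C*R*S, K), then back to (N, K, P, Q) with P = Q = 14.
    out = (rows @ cols).reshape(4, 14, 14, 512).permute(0, 3, 1, 2)
    assert torch.allclose(out, F.conv2d(x, w), atol=1e-3)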
tests/test_matmul.py
@@ -6,40 +6,46 @@ from ninetoothed import Symbol, Tensor
  from tests.skippers import skip_if_cuda_not_available, skip_if_float8_e5m2_not_supported
 
 
- def matmul(lhs, rhs):
+ def arrangement(lhs, rhs, output):
      BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True)
      BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True)
      BLOCK_SIZE_K = Symbol("BLOCK_SIZE_K", meta=True)
 
-     output_tiled = Tensor(2).tile((BLOCK_SIZE_M, BLOCK_SIZE_N))
+     output_tiled = output.tile((BLOCK_SIZE_M, BLOCK_SIZE_N))
 
      lhs_tiled = (
-         Tensor(2)
-         .tile((BLOCK_SIZE_M, BLOCK_SIZE_K))
+         lhs.tile((BLOCK_SIZE_M, BLOCK_SIZE_K))
          .tile((1, -1))
          .expand((-1, output_tiled.shape[1]))
      )
      lhs_tiled.dtype = lhs_tiled.dtype.squeeze(0)
 
      rhs_tiled = (
-         Tensor(2)
-         .tile((BLOCK_SIZE_K, BLOCK_SIZE_N))
+         rhs.tile((BLOCK_SIZE_K, BLOCK_SIZE_N))
          .tile((-1, 1))
          .expand((output_tiled.shape[0], -1))
      )
      rhs_tiled.dtype = rhs_tiled.dtype.squeeze(1)
 
-     @ninetoothed.jit
-     def matmul_kernel(lhs: lhs_tiled, rhs: rhs_tiled, output: output_tiled):
-         accumulator = ntl.zeros(output.shape, dtype=ntl.float32)
-         for k in range(lhs.shape[0]):
-             accumulator += ntl.dot(lhs[k], rhs[k])
-         output = accumulator.to(ntl.float16)
+     return lhs_tiled, rhs_tiled, output_tiled
+
+
+ def application(lhs, rhs, output):
+     accumulator = ntl.zeros(output.shape, dtype=ntl.float32)
+     for k in range(lhs.shape[0]):
+         accumulator += ntl.dot(lhs[k], rhs[k])
+     output = accumulator.to(ntl.float16)
+
 
+ def matmul(lhs, rhs):
      output = torch.empty(
          (lhs.shape[0], rhs.shape[1]), device=lhs.device, dtype=torch.float16
      )
 
+     matmul_kernel = ninetoothed.make(
+         arrangement, application, (Tensor(2), Tensor(2), Tensor(2))
+     )
+
      matmul_kernel(lhs, rhs, output)
 
      return output