ninetoothed 0.16.0__tar.gz → 0.17.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/.github/workflows/pytest.yml +1 -0
  2. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/PKG-INFO +1 -1
  3. ninetoothed-0.17.0/docs/source/basics.rst +254 -0
  4. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/docs/source/conf.py +22 -0
  5. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/docs/source/index.rst +1 -0
  6. ninetoothed-0.17.0/docs/source/visualize.py +189 -0
  7. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/pyproject.toml +1 -1
  8. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/src/ninetoothed/aot.py +15 -6
  9. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/src/ninetoothed/generation.py +62 -6
  10. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/src/ninetoothed/language.py +4 -0
  11. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/src/ninetoothed/visualization.py +19 -12
  12. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/tests/test_addmm.py +2 -2
  13. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/tests/test_attention.py +23 -10
  14. ninetoothed-0.17.0/tests/test_dropout.py +54 -0
  15. ninetoothed-0.17.0/tests/test_pow.py +50 -0
  16. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/.gitattributes +0 -0
  17. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/.github/ISSUE_TEMPLATE/bug-report.yml +0 -0
  18. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/.github/ISSUE_TEMPLATE/feature-request.yml +0 -0
  19. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/.github/pull_request_template.md +0 -0
  20. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/.github/workflows/publish-to-pypi.yml +0 -0
  21. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/.github/workflows/ruff.yml +0 -0
  22. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/.github/workflows/sphinx.yml +0 -0
  23. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/.gitignore +0 -0
  24. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/LICENSE +0 -0
  25. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/README.md +0 -0
  26. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/docs/Makefile +0 -0
  27. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/docs/README.zh.md +0 -0
  28. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/docs/make.bat +0 -0
  29. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/docs/requirements.txt +0 -0
  30. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/docs/source/_static/matmul-tiling.png +0 -0
  31. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/docs/source/_static/ninetoothed-logo.png +0 -0
  32. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/docs/source/_static/vecadd-tiling.png +0 -0
  33. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/docs/source/code_generation.rst +0 -0
  34. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/docs/source/installation.rst +0 -0
  35. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/docs/source/python_api.rst +0 -0
  36. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/docs/source/symbol.rst +0 -0
  37. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/docs/source/tensor.rst +0 -0
  38. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/docs/source/visualization.rst +0 -0
  39. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/requirements.txt +0 -0
  40. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/src/ninetoothed/__init__.py +0 -0
  41. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/src/ninetoothed/cudaifier.py +0 -0
  42. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/src/ninetoothed/dtype.py +0 -0
  43. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/src/ninetoothed/jit.py +0 -0
  44. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/src/ninetoothed/make.py +0 -0
  45. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/src/ninetoothed/naming.py +0 -0
  46. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/src/ninetoothed/symbol.py +0 -0
  47. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/src/ninetoothed/tensor.py +0 -0
  48. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/src/ninetoothed/torchifier.py +0 -0
  49. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/src/ninetoothed/utils.py +0 -0
  50. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/tests/__init__.py +0 -0
  51. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/tests/skippers.py +0 -0
  52. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/tests/test_add.py +0 -0
  53. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/tests/test_aot.py +0 -0
  54. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/tests/test_conv2d.py +0 -0
  55. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/tests/test_matmul.py +0 -0
  56. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/tests/test_max_pool2d.py +0 -0
  57. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/tests/test_naming.py +0 -0
  58. {ninetoothed-0.16.0 → ninetoothed-0.17.0}/tests/test_softmax.py +0 -0
.github/workflows/pytest.yml
@@ -12,6 +12,7 @@ jobs:
  - name: Install dependencies
  run: |
  python -m pip install --upgrade pip
+ pip install -e .[all]
  pip install -r requirements.txt
  - name: Test with pytest
  run: |
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: ninetoothed
- Version: 0.16.0
+ Version: 0.17.0
  Summary: A domain-specific language based on Triton but providing higher-level abstraction.
  Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
  Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
docs/source/basics.rst (new file)
@@ -0,0 +1,254 @@
The Basics
==========

Symbols
-------

The concept of **symbols** is similar to what is described in the `SymPy tutorial <https://docs.sympy.org/latest/tutorials/intro-tutorial/intro.html>`_. Symbols do not store actual numerical values; instead, they represent symbolic names or symbolic expressions. This allows for performing symbolic mathematical operations.

In NineToothed, you can create a symbol using the ``Symbol`` class. For example, in the code below, we first create two symbols named ``BLOCK_SIZE_M`` and ``BLOCK_SIZE_N``, and then perform a multiplication operation on them:

.. code-block::

    >>> from ninetoothed import Symbol
    >>> BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M")
    >>> BLOCK_SIZE_M
    BLOCK_SIZE_M
    >>> BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N")
    >>> BLOCK_SIZE_N
    BLOCK_SIZE_N
    >>> BLOCK_SIZE_M * BLOCK_SIZE_N
    BLOCK_SIZE_M * BLOCK_SIZE_N

Symbolic Tensors
----------------

Similar to many deep learning frameworks, tensors are a core concept in NineToothed. However, tensors in NineToothed differ slightly from those in other frameworks—they do not store actual data. Instead, they store symbolic expressions in member variables such as ``shape`` and ``strides``. For this reason, we refer to them as **symbolic tensors**.

In NineToothed, you can create a tensor using the ``Tensor`` class. As shown in the example below, ``Tensor(2)`` creates a 2-dimensional tensor—essentially, a matrix. Note that the ``shape`` member contains symbolic expressions rather than concrete values:

.. code-block::

    >>> from ninetoothed import Tensor
    >>> x = Tensor(2)
    >>> x.shape
    (ninetoothed_tensor_0_size_0, ninetoothed_tensor_0_size_1)

Tensor-Oriented Metaprogramming
-------------------------------

Thanks to symbolic tensors, we can perform certain compile-time operations on tensors in NineToothed. These operations are called **meta-operations**, such as ``tile``, ``expand``, ``squeeze``, ``permute``, and so on.

For example, in the following code, we apply the ``tile`` operation to ``x``, which divides ``x`` into blocks of shape ``(BLOCK_SIZE_M, BLOCK_SIZE_N)``:

.. code-block::

    >>> x_tiled = x.tile((BLOCK_SIZE_M, BLOCK_SIZE_N))
    >>> x_tiled.shape
    ((ninetoothed_tensor_0_size_0 - (BLOCK_SIZE_M - 1) - 1 + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M + 1, (ninetoothed_tensor_0_size_1 - (BLOCK_SIZE_N - 1) - 1 + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N + 1)
    >>> x_tiled.dtype.shape
    (BLOCK_SIZE_M, BLOCK_SIZE_N)

We notice that the ``dtype`` of ``x_tiled`` also has a ``shape`` attribute. This is because tensors in NineToothed can be nested—that is, the elements of a tensor can themselves be tensors.

In other words, during the ``tile`` operation, we create a two-level tensor: Each element of the outer tensor is itself an inner tensor. To make this easier to understand, let's walk through a numerical example:

.. code-block::

    >>> BLOCK_SIZE_M = 2
    >>> BLOCK_SIZE_N = 2
    >>> x = Tensor(shape=(4, 8))
    >>> x_tiled = x.tile((BLOCK_SIZE_M, BLOCK_SIZE_N))
    >>> x_tiled.shape
    (2, 4)
    >>> x_tiled.dtype.shape
    (2, 2)

As shown in the figure below, we've tiled the original tensor ``x`` of shape ``(4, 8)`` into blocks of shape ``(2, 2)`` (the inner tensors), resulting in a total of ``(2, 4)`` such blocks (the outer tensor):

.. image:: generated/x-tiled.png

Arrange-and-Apply Paradigm
--------------------------

After introducing meta-operations, you should be able to understand what operations can be performed on individual tensors at compile time. A series of such operations is referred to as **arrangement**. However, this is not enough, as we also need to establish the relationship between multiple parameter tensors.

Such relationships are managed by the NineToothed compiler: The NineToothed compiler launches programs based on the shape of the outermost tensors of the arranged parameter tensors and maps the second outermost tensors to these programs.

We can understand this concept using a simple arrangement function:

.. code-block:: python

    def arrangement(x, y, z, BLOCK_SIZE=ninetoothed.block_size()):
        return x.tile((BLOCK_SIZE,)), y.tile((BLOCK_SIZE,)), z.tile((BLOCK_SIZE,))

In this function, we apply the ``tile`` operation to the vectors ``x``, ``y``, and ``z`` to divide each vector into blocks of size ``BLOCK_SIZE``. For example, if each vector's length is ``16`` and ``BLOCK_SIZE`` is ``2``, each vector can be divided into ``8`` blocks, with each block having a length of ``2``. The arranged ``x``, ``y``, and ``z`` would then look as follows:

.. image:: generated/x-arranged.png

.. image:: generated/y-arranged.png

.. image:: generated/z-arranged.png

Based on this arrangement, the NineToothed compiler will launch ``8`` programs and map the elements of the outermost tensors of the arranged ``x``, ``y``, and ``z`` (i.e., the second outermost tensors) to these ``8`` programs.

Now that we have these mappings, we can launch the programs accordingly. However, we are still one step away from fully implementing an algorithm, because we have not defined what each program should do. In other words, we need to define how to apply the arranged tensors. In NineToothed, this can be done by defining an **application** function.

For example, to define vector addition, we can create the following application function:

.. code-block:: python

    def application(x, y, z):
        z = x + y

The logic of the code is simple: It adds ``x`` and ``y`` and stores the result in ``z``. However, it is important to note that the parameters of the application function are the elements of the outermost tensors of the arranged parameter tensors (i.e., the second outermost tensors), not the tensors themselves. That is, based on the above assumptions, ``x``, ``y``, and ``z`` here represent blocks of length ``2``, not the original tensors of length ``16``.

At this point, we have defined both the arrangement and application functions. The remaining task is to integrate them into a compute kernel. In NineToothed, we can use ``ninetoothed.make`` to achieve this:

.. code-block:: python

    kernel = ninetoothed.make(arrangement, application, (Tensor(1), Tensor(1), Tensor(1)))

This code means that we want to arrange three 1-dimensional tensors (vectors) according to the ``arrangement`` function, and apply the arranged tensors using the ``application`` function to form a compute kernel ``kernel``. The paradigm of constructing a compute kernel this way is called the **arrange-and-apply paradigm**.

We can invoke ``kernel`` as follows:

.. code-block:: python

    import torch

    dtype = torch.float16
    device = "cuda"

    x = torch.tensor((1, 2, 3), dtype=dtype, device=device)
    y = torch.tensor((4, 5, 6), dtype=dtype, device=device)

    z = torch.empty_like(x)
    kernel(x, y, z)

    reference = torch.tensor((5, 7, 9), dtype=dtype, device=device)
    assert torch.allclose(z, reference)

As shown, when we call ``kernel``, we do not provide an actual value for ``BLOCK_SIZE``. This is because when constructing ``BLOCK_SIZE``, we used ``ninetoothed.block_size``, which represents that we want to use the configurations generated by the NineToothed compiler for auto-tuning. If we want to provide a value manually (for example, during debugging), we can directly assign a specific value as follows:

.. code-block:: python

    def arrangement(x, y, z, BLOCK_SIZE=1024):
        return x.tile((BLOCK_SIZE,)), y.tile((BLOCK_SIZE,)), z.tile((BLOCK_SIZE,))

Indexing and Iteration
----------------------

Through the vector addition example, we got a brief understanding of how to develop compute kernels using NineToothed. In that example, the parameter tensors were arranged into two-level tensors. However, tensors in NineToothed are not limited to just two levels—they can be three-level or even more. Only the outermost level of an arranged tensor is used to launch programs. In other words, tensors with more than two levels are hierarchical and can be indexed and iterated over within the application function.

Now, let's implement a matrix multiplication kernel to better understand indexing and iteration in NineToothed, as well as deepen our understanding of the arrange-and-apply paradigm.

Before we begin implementation, we first need to understand the algorithm we want to realize. Here's a diagram for the algorithm:

.. image:: generated/tiled-matrix-multiplication.png

In simple terms, we tile three matrices into blocks. For each block in :math:`C`, we need to iterate over the corresponding row of blocks of :math:`A` and the corresponding column of blocks of :math:`B`. Then, for each iteration, we need to perform a small matrix multiplication between the blocks of :math:`A` and :math:`B`, and accumulate the result into the block of :math:`C`.

With this algorithm in mind, let's begin the implementation, starting with the arrangement phase:

.. code-block:: python

    BLOCK_SIZE_M = ninetoothed.block_size()
    BLOCK_SIZE_N = ninetoothed.block_size()
    BLOCK_SIZE_K = ninetoothed.block_size()


    def arrangement(input, other, output):
        output_arranged = output.tile((BLOCK_SIZE_M, BLOCK_SIZE_N))

        input_arranged = input.tile((BLOCK_SIZE_M, BLOCK_SIZE_K))
        input_arranged = input_arranged.tile((1, -1))
        input_arranged = input_arranged.expand((-1, output_arranged.shape[1]))
        input_arranged.dtype = input_arranged.dtype.squeeze(0)

        other_arranged = other.tile((BLOCK_SIZE_K, BLOCK_SIZE_N))
        other_arranged = other_arranged.tile((-1, 1))
        other_arranged = other_arranged.expand((output_arranged.shape[0], -1))
        other_arranged.dtype = other_arranged.dtype.squeeze(1)

        return input_arranged, other_arranged, output_arranged

In this code, we first define the symbols ``BLOCK_SIZE_M``, ``BLOCK_SIZE_N``, and ``BLOCK_SIZE_K``, which represent the shapes of the blocks. We then tile ``output`` into blocks of shape ``(BLOCK_SIZE_M, BLOCK_SIZE_N)``, ``input`` into ``(BLOCK_SIZE_M, BLOCK_SIZE_K)``, and ``other`` into ``(BLOCK_SIZE_K, BLOCK_SIZE_N)``:

.. image:: generated/input-arranged-0.png

.. image:: generated/other-arranged-0.png

.. image:: generated/output-arranged-0.png

We notice that simple arrangement is not enough for matrix multiplication. According to the diagram, each block in ``output`` corresponds to a row of blocks in ``input`` and a column of blocks in ``other``. So we need to further tile ``input`` row-wise and ``other`` column-wise:

.. image:: generated/input-arranged-1.png

.. image:: generated/other-arranged-1.png

But we're still not done. Remember how the NineToothed compiler establishes the relationship between multiple parameter tensors?

The NineToothed compiler launches programs based on the shape of the outermost tensors of the arranged parameter tensors and maps the second outermost tensors to these programs.

Why is this important? Because it implies a crucial rule: The outermost tensors of the arranged parameter tensors must have the same shape.

Currently, the shapes of the outermost tensors of the arranged parameter tensors are ``(4, 1)``, ``(1, 4)``, and ``(4, 4)``—clearly inconsistent. This suggests that the arrangement is incorrect or incomplete. From the diagram, we know we need to align each row of blocks of ``input`` with each column of blocks of ``other``. We can achieve this via ``expand``, horizontally expanding ``input`` and vertically expanding ``other`` to match the shape of ``output``:

.. image:: generated/input-arranged-2.png

.. image:: generated/other-arranged-2.png

Now, the outermost tensors of the arranged parameter tensors have matching shapes. Technically, arrangement is complete and we could proceed to write the application function. However, we notice that the row of blocks of ``input`` and the column of blocks of ``other`` are two-dimensional, and their shapes are of the form ``(1, ...)`` and ``(..., 1)`` respectively. In other words, if we do not perform additional operations, the way to index the row of blocks and the column of blocks would be ``input[0, k]`` and ``other[k, 0]``. If we want to find the range of ``k`` based on ``input``, we would need to use ``input.shape[1]``. But we know that dimensions of size ``1`` can be safely removed here. That's why we add ``squeeze``:

.. image:: generated/input-arranged-3.png

.. image:: generated/other-arranged-3.png

With this, we can now index the row of blocks and the column of blocks with ``input[k]`` and ``other[k]``, and use ``input.shape[0]`` to determine the range of ``k``.

At this point, the arrangement phase is complete. The final arrangement result is:

.. image:: generated/input-arranged-3.png

.. image:: generated/other-arranged-3.png

.. image:: generated/output-arranged-0.png

Now let's look at the application function:

.. code-block:: python

    def application(input, other, output):
        accumulator = ntl.zeros(output.shape, dtype=ntl.float32)

        for k in range(input.shape[0]):
            accumulator += ntl.dot(input[k], other[k])

        output = accumulator

Within the function body, we first define an ``accumulator`` to accumulate intermediate results. Then, we iterate over the row of blocks of ``input`` and the column of blocks of ``other``, and accumulate the results of small matrix multiplications into the ``accumulator``. Finally, we write the ``accumulator`` into the corresponding block of ``output``. Since this happens on each block of ``output``, the overall matrix multiplication is completed.

Just like in the vector addition example, after defining ``arrangement`` and ``application``, we can integrate them using ``ninetoothed.make`` to build a kernel:

.. code-block:: python

    kernel = ninetoothed.make(arrangement, application, (Tensor(2), Tensor(2), Tensor(2)))

The kernel can be invoked like this:

.. code-block:: python

    import torch

    dtype = torch.float16
    device = "cuda"

    input = torch.tensor(((1, 2), (3, 4)), dtype=dtype, device=device)
    other = torch.tensor(((5, 6), (7, 8)), dtype=dtype, device=device)

    output = torch.empty((input.shape[0], other.shape[1]), dtype=dtype, device=device)
    kernel(input, other, output)

    reference = torch.tensor(((19, 22), (43, 50)), dtype=dtype, device=device)
    assert torch.allclose(output, reference)
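
Note: the vector-addition walkthrough in basics.rst above can be assembled into a single runnable script. The following is a minimal sketch built only from the snippets in that document (same arrangement, application, and ninetoothed.make call); it is not additional API introduced in 0.17.0.

import ninetoothed
import torch
from ninetoothed import Tensor


def arrangement(x, y, z, BLOCK_SIZE=ninetoothed.block_size()):
    # Split each vector into blocks of BLOCK_SIZE elements.
    return x.tile((BLOCK_SIZE,)), y.tile((BLOCK_SIZE,)), z.tile((BLOCK_SIZE,))


def application(x, y, z):
    # Each program adds one block of x and y and writes one block of z.
    z = x + y


kernel = ninetoothed.make(arrangement, application, (Tensor(1), Tensor(1), Tensor(1)))

x = torch.tensor((1, 2, 3), dtype=torch.float16, device="cuda")
y = torch.tensor((4, 5, 6), dtype=torch.float16, device="cuda")
z = torch.empty_like(x)

kernel(x, y, z)
assert torch.allclose(z, torch.tensor((5, 7, 9), dtype=torch.float16, device="cuda"))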
docs/source/conf.py
@@ -3,6 +3,18 @@
  # For the full list of built-in configuration values, see the documentation:
  # https://www.sphinx-doc.org/en/master/usage/configuration.html

+ import os
+ import sys
+
+ sys.path.insert(0, os.path.abspath("."))
+
+ from visualize import (
+ visualize_add,
+ visualize_mm,
+ visualize_tiled_matrix_multiplication,
+ visualize_x_tiled,
+ )
+
  # -- Project information -----------------------------------------------------
  # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information

@@ -36,3 +48,13 @@ html_theme_options = {
  }
  ]
  }
+
+ os.makedirs("generated", exist_ok=True)
+
+ visualize_x_tiled(4, 8, 2, 2)
+
+ visualize_add(16, 2)
+
+ visualize_tiled_matrix_multiplication(8, 8, 8, 2, 2, 2)
+
+ visualize_mm(8, 8, 8, 2, 2, 2)
docs/source/index.rst
@@ -11,5 +11,6 @@ NineToothed Documentation
  :maxdepth: 2

  installation
+ basics
  python_api
  visualization
docs/source/visualize.py (new file)
@@ -0,0 +1,189 @@
import matplotlib.pyplot as plt

from ninetoothed import Tensor
from ninetoothed.visualization import (
    _prepare_figure_and_axes,
    _visualize_tensor,
    visualize,
)


def visualize_x_tiled(m, n, block_size_m, block_size_n):
    BLOCK_SIZE_M = block_size_m
    BLOCK_SIZE_N = block_size_n

    x = Tensor(shape=(m, n))
    x_tiled = x.tile((BLOCK_SIZE_M, BLOCK_SIZE_N))

    visualize(x_tiled, color="C0", save_path="generated/x-tiled.png")


def visualize_add(size, block_size):
    BLOCK_SIZE = block_size

    x = Tensor(shape=(size,))
    y = Tensor(shape=(size,))
    z = Tensor(shape=(size,))

    x_arranged = x.tile((BLOCK_SIZE,))
    visualize(x_arranged, color="C0", save_path="generated/x-arranged.png")

    y_arranged = y.tile((BLOCK_SIZE,))
    visualize(y_arranged, color="C1", save_path="generated/y-arranged.png")

    z_arranged = z.tile((BLOCK_SIZE,))
    visualize(z_arranged, color="C2", save_path="generated/z-arranged.png")


def visualize_tiled_matrix_multiplication(
    m, n, k, block_size_m, block_size_n, block_size_k
):
    BLOCK_SIZE_M = block_size_m
    BLOCK_SIZE_N = block_size_n
    BLOCK_SIZE_K = block_size_k

    a = Tensor(shape=(m, k))
    b = Tensor(shape=(k, n))
    c = Tensor(shape=(m, n))

    a_tiled = a.tile((BLOCK_SIZE_M, BLOCK_SIZE_K))
    b_tiled = b.tile((BLOCK_SIZE_K, BLOCK_SIZE_N))
    c_tiled = c.tile((BLOCK_SIZE_M, BLOCK_SIZE_N))

    a_tile = a_tiled.innermost()
    b_tile = b_tiled.innermost()
    c_tile = c_tiled.innermost()

    color_0 = "#1f77b4"
    color_1 = "#ff7f0e"
    color_2 = "#2ca02c"

    a_color = _lighten_color(color_0, 20)
    b_color = _lighten_color(color_1, 20)
    c_color = _lighten_color(color_2, 20)

    a_tile_base_color = color_0
    b_tile_base_color = color_1
    c_tile_base_color = color_2

    def _visualize_matrices(ax):
        a_verts, a_max_pos_x, a_max_pos_y = _visualize_tensor(ax, a, 0, 0, a_color)
        b_verts, b_max_pos_x, b_max_pos_y = _visualize_tensor(
            ax, b, a_max_pos_x + 2, a_max_pos_y + 2, b_color
        )
        c_verts, _, _ = _visualize_tensor(ax, c, 0, a_max_pos_y + 2, c_color)

        a_min_pos_x, a_min_pos_y = a_verts[0][0]
        b_min_pos_x, b_min_pos_y = b_verts[0][0]
        c_min_pos_x, c_min_pos_y = c_verts[0][0]

        percentage = 30

        for i in range(0, a_tiled.shape[1]):
            y_offset = i * a_tile.shape[1]

            x = a_min_pos_x + (a_tiled.shape[0] - 2) * a_tile.shape[0]
            y = a_min_pos_y + y_offset

            darkened_color = _darken_color(
                a_tile_base_color, (i + 1) / a_tiled.shape[1] * percentage
            )

            _visualize_tensor(ax, a_tile, x, y, darkened_color)

        for i in range(0, b_tiled.shape[0]):
            x_offset = i * b_tile.shape[0]

            x = b_min_pos_x + x_offset
            y = b_min_pos_y + b_tile.shape[1]

            darkened_color = _darken_color(
                b_tile_base_color,
                (b_tiled.shape[0] - i) / b_tiled.shape[0] * percentage,
            )

            _visualize_tensor(ax, b_tile, x, y, darkened_color)

        _visualize_tensor(
            ax,
            c_tile,
            c_min_pos_x + (a_tiled.shape[0] - 2) * a_tile.shape[0],
            c_min_pos_y + b_tile.shape[1],
            _darken_color(c_tile_base_color, percentage),
        )

        return b_max_pos_x, b_max_pos_y

    max_pos_x, max_pos_y = _visualize_matrices(plt.gca())

    width = max_pos_y + 1
    height = max_pos_x + 1

    _, ax = _prepare_figure_and_axes(width, height)

    _visualize_matrices(ax)

    save_path = "generated/tiled-matrix-multiplication.png"
    plt.savefig(save_path, transparent=True, bbox_inches="tight", pad_inches=0)

    plt.close()


def visualize_mm(m, n, k, block_size_m, block_size_n, block_size_k):
    BLOCK_SIZE_M = block_size_m
    BLOCK_SIZE_N = block_size_n
    BLOCK_SIZE_K = block_size_k

    input = Tensor(shape=(m, k))
    other = Tensor(shape=(k, n))
    output = Tensor(shape=(m, n))

    output_arranged = output.tile((BLOCK_SIZE_M, BLOCK_SIZE_N))
    visualize(output_arranged, color="C2", save_path="generated/output-arranged-0.png")

    input_arranged = input.tile((BLOCK_SIZE_M, BLOCK_SIZE_K))
    visualize(input_arranged, color="C0", save_path="generated/input-arranged-0.png")
    input_arranged = input_arranged.tile((1, -1))
    visualize(input_arranged, color="C0", save_path="generated/input-arranged-1.png")
    input_arranged = input_arranged.expand((-1, output_arranged.shape[1]))
    visualize(input_arranged, color="C0", save_path="generated/input-arranged-2.png")
    input_arranged.dtype = input_arranged.dtype.squeeze(0)
    visualize(input_arranged, color="C0", save_path="generated/input-arranged-3.png")

    other_arranged = other.tile((BLOCK_SIZE_K, BLOCK_SIZE_N))
    visualize(other_arranged, color="C1", save_path="generated/other-arranged-0.png")
    other_arranged = other_arranged.tile((-1, 1))
    visualize(other_arranged, color="C1", save_path="generated/other-arranged-1.png")
    other_arranged = other_arranged.expand((output_arranged.shape[0], -1))
    visualize(other_arranged, color="C1", save_path="generated/other-arranged-2.png")
    other_arranged.dtype = other_arranged.dtype.squeeze(1)
    visualize(other_arranged, color="C1", save_path="generated/other-arranged-3.png")


def _darken_color(hex_color, percentage):
    rgb = _hex_to_rgb(hex_color)
    factor = (100 - percentage) / 100
    darkened_rgb = tuple(int(c * factor) for c in rgb)

    return _rgb_to_hex(darkened_rgb)


def _lighten_color(hex_color, percentage):
    rgb = _hex_to_rgb(hex_color)
    factor = percentage / 100
    lightened_rgb = tuple(int(c + (255 - c) * factor) for c in rgb)

    return _rgb_to_hex(lightened_rgb)


def _hex_to_rgb(hex_color):
    hex_color = hex_color.lstrip("#")

    if len(hex_color) == 3:
        hex_color = "".join(c * 2 for c in hex_color)

    return tuple(int(hex_color[i : i + 2], 16) for i in (0, 2, 4))


def _rgb_to_hex(rgb_color):
    return "#{:02x}{:02x}{:02x}".format(*rgb_color)
pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

  [project]
  name = "ninetoothed"
- version = "0.16.0"
+ version = "0.17.0"
  authors = [{ name = "Jiacheng Huang", email = "huangjiacheng0709@outlook.com" }]
  description = "A domain-specific language based on Triton but providing higher-level abstraction."
  readme = "README.md"
src/ninetoothed/aot.py
@@ -31,7 +31,7 @@ def _aot(func, caller, kernel_name, num_warps, num_stages):

  _HEADER_PATH.parent.mkdir(exist_ok=True)

- if not _HEADER_PATH.exists():
+ if not _HEADER_PATH.exists() or _HEADER_PATH.read_text() != _HEADER_CONTENT:
  _HEADER_PATH.write_text(_HEADER_CONTENT)

  code_generator = CodeGenerator()
@@ -91,20 +91,29 @@ def _aot(func, caller, kernel_name, num_warps, num_stages):

  c_header_file_name = f"{kernel_name}.{signature_hash}.h"
  c_header_file = output_contents[c_header_file_name]
- c_header_file = f"{c_header_file}\n{unparser.header};\n"
+ c_header_file = f'{c_header_file}\n#ifdef __cplusplus\nextern "C" {unparser.header};\n#else\n{unparser.header};\n#endif\n'
  c_header_file = c_header_file.replace("<stdint.h>", f'"{_HEADER_PATH}"')
  output_contents[c_header_file_name] = c_header_file

  return output_contents


- _HEADER_CONTENT = """#include <stdint.h>
+ _HEADER_CONTENT = """#ifndef NINETOOTHED_H
+ #define NINETOOTHED_H
+
+ #include <stdint.h>

  typedef struct {
- uintptr_t data;
+ void *data;
  uint64_t *shape;
  int64_t *strides;
  } NineToothedTensor;
+
+ typedef void *NineToothedStream;
+
+ typedef int NineToothedResult;
+
+ #endif // NINETOOTHED_H
  """

  _HEADER_PATH = CACHE_DIR / "ninetoothed.h"
@@ -135,9 +144,9 @@ class _Unparser:
  return f"return {self._generic_unparse(call)};"

  def _unparse_FunctionDef(self, node):
- params = ["CUstream stream"]
+ params = ["NineToothedStream stream"]
  params += [f"NineToothedTensor {arg.arg}" for arg in node.args.args]
- header = f"CUresult {node.name}({', '.join(params)})"
+ header = f"NineToothedResult {node.name}({', '.join(params)})"

  self.header = header

src/ninetoothed/generation.py
@@ -19,6 +19,7 @@ import uuid
  import sympy
  import triton
  import triton.language as tl
+ from triton.language.extra import libdevice

  import ninetoothed.naming as naming
  from ninetoothed.cudaifier import Cudaifier
@@ -225,6 +226,41 @@ class CodeGenerator(ast.NodeTransformer):

  return node

+ def visit_Call(self, node):
+ def _offsets(tensor, dim=None):
+ if dim is None:
+ return tensor._last_generated_overall_offsets.node
+
+ offsets = tensor._last_generated_offsets
+
+ if dim < 0:
+ dim += tensor.source.ndim
+
+ return sum(
+ offsets[dim][target_dim] for target_dim in range(tensor.target.ndim)
+ ).node
+
+ func = node.func
+ args = node.args
+
+ if isinstance(func, ast.Attribute):
+ if func.attr == "offsets":
+ value = func.value
+
+ if self._in_context(value):
+ tensor = self._context[value.id]
+ elif isinstance(value, ast.Subscript) and self._in_context(value.value):
+ tensor = self._context[value.value.id]
+
+ self.visit(value)
+
+ # TODO: Add error handling.
+ return _offsets(tensor, ast.literal_eval(args[0]) if args else None)
+
+ self.generic_visit(node)
+
+ return node
+
  def visit_Subscript(self, node):
  if self._in_context(node.value) and isinstance(node.ctx, ast.Load):
  value = self._context[node.value.id]
@@ -242,13 +278,24 @@ class CodeGenerator(ast.NodeTransformer):
  return node

  def visit_Attribute(self, node):
- if self._in_context(node.value):
- value = self._context[node.value.id]
+ value = node.value

- if isinstance(value, Tensor):
- inner = value.dtype
+ if isinstance(value, ast.Attribute):
+ value = self.visit_Attribute(value)
+
+ if self._in_context(value):
+ value = self._context[value.id].dtype
+
+ if isinstance(value, Tensor):
+ attr = getattr(value, node.attr)
+
+ if node.attr == "dtype" and attr is None:
+ return Symbol(f"{value.source.pointer_string()}.type.element_ty").node

- return Symbol(getattr(inner, node.attr)).node
+ if isinstance(attr, Tensor):
+ return attr
+
+ return Symbol(attr).node

  self.generic_visit(node)

@@ -560,6 +607,8 @@ class CodeGenerator(ast.NodeTransformer):
  indices = self._complete_indices(tensor, indices)
  offsets = type(self)._generate_offsets(tensor, indices)

+ tensor._last_generated_offsets = offsets
+
  for source_dim in range(tensor.source.ndim):
  for target_dim in range(tensor.target.ndim):
  if target_dim not in invariant_target_dims:
@@ -584,7 +633,7 @@ class CodeGenerator(ast.NodeTransformer):
  * tensor.source.strides[source_dim]
  )

- pointers = name_for_pointers + sum(
+ overall_offsets = sum(
  offsets[source_dim][target_dim][
  type(self)._generate_slices(tensor, target_dim)
  ]
@@ -594,6 +643,10 @@ class CodeGenerator(ast.NodeTransformer):
  if target_dim not in invariant_target_dims
  and offsets[source_dim][target_dim] != 0
  )
+
+ tensor._last_generated_overall_offsets = overall_offsets
+
+ pointers = name_for_pointers + overall_offsets
  mask = functools.reduce(
  lambda x, y: x & y,
  (
@@ -980,6 +1033,9 @@ class _Inliner(ast.NodeTransformer):
  if func_def is None:
  return None, []

+ if inspect.getmodule(func) is libdevice:
+ return None, []
+
  collector = _ImportCollector()
  collector.visit(ast.parse(inspect.getsource(inspect.getmodule(func))))
  self.imports.extend(collector.imports)
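
Note: the new visit_Call branch above wires up an .offsets() helper on arranged tensors inside application functions: with no argument it returns the overall offsets last generated for the current block, and with a dimension index it sums the per-dimension offsets. A minimal sketch of how an application function uses it, taken from tests/test_dropout.py later in this diff (the masking variant k[i].offsets(-2) appears in tests/test_attention.py):

import ninetoothed.language as ntl


def application(input, p, seed, output):
    # input.offsets() exposes each element's offset within the source tensor,
    # which serves here as the per-element offset for Triton's RNG.
    output = ntl.where(ntl.rand(seed, input.offsets()) > p, input / (1 - p), 0)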
src/ninetoothed/language.py
@@ -1,7 +1,11 @@
  import ast

+ from triton.language.extra import libdevice
+
  from ninetoothed.symbol import Symbol

+ __all__ = ["libdevice"]
+
  LANGUAGE = "ninetoothed.language"

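
Note: re-exporting libdevice from ninetoothed.language makes Triton's libdevice intrinsics available inside application functions, and the _Inliner change above skips inlining them. A usage sketch, lifted from tests/test_pow.py at the end of this diff:

from ninetoothed.language import libdevice


def application(input, exponent, output):
    # libdevice.pow is applied elementwise to the current block; the code
    # generator leaves the call as-is instead of trying to inline it.
    output = libdevice.pow(input, exponent)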
src/ninetoothed/visualization.py
@@ -10,8 +10,6 @@ def visualize(tensor, color=None, save_path=None):
  :param color: The color to be used for visualization.
  :param save_path: The path where the visualization should be saved.
  """
- outline_width = 0.1
- plt.rcParams["lines.linewidth"] = 72 * outline_width

  if color is None:
  color = f"C{visualize.count}"
@@ -21,6 +19,24 @@ def visualize(tensor, color=None, save_path=None):
  width = max_pos_y + 1
  height = max_pos_x + 1

+ _, ax = _prepare_figure_and_axes(width, height)
+
+ _visualize_tensor(ax, tensor, 0, 0, color)
+
+ plt.savefig(save_path, transparent=True, bbox_inches="tight", pad_inches=0)
+
+ plt.close()
+
+ visualize.count += 1
+
+
+ visualize.count = 0
+
+
+ def _prepare_figure_and_axes(width, height):
+ outline_width = 0.1
+ plt.rcParams["lines.linewidth"] = 72 * outline_width
+
  fig = plt.figure(figsize=(width + outline_width, height + outline_width))

  h = (Size.Fixed(0), Size.Fixed(width + outline_width))
@@ -41,16 +57,7 @@ def visualize(tensor, color=None, save_path=None):
  plt.xlim((-half_outline_width, width + half_outline_width))
  plt.ylim((-half_outline_width, height + half_outline_width))

- _visualize_tensor(ax, tensor, 0, 0, color)
-
- plt.savefig(save_path, transparent=True, bbox_inches="tight", pad_inches=0)
-
- plt.close()
-
- visualize.count += 1
-
-
- visualize.count = 0
+ return fig, ax


  def _visualize_tensor(ax, tensor, x, y, color, level_spacing=4):
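
Note: this refactor only splits figure setup out of visualize() into _prepare_figure_and_axes() so that docs/source/visualize.py (added above) can compose several tensors into one figure; the public visualize() call is unchanged. A minimal usage sketch based on the calls already shown in this diff (the output path is illustrative):

from ninetoothed import Tensor
from ninetoothed.visualization import visualize

# Tile an 8x8 symbolic tensor into 2x2 blocks and render the two-level layout.
x_tiled = Tensor(shape=(8, 8)).tile((2, 2))
visualize(x_tiled, color="C0", save_path="x-tiled.png")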
tests/test_addmm.py
@@ -12,11 +12,11 @@ from tests.skippers import skip_if_cuda_not_available, skip_if_float8_e5m2_not_s
  def arrangement(input, mat1, mat2, beta, alpha, output):
  _, _, input_arranged = matmul.arrangement(mat1, mat2, input)

- mat1_arrange, mat2_arranged, output_arranged = matmul.arrangement(
+ mat1_arranged, mat2_arranged, output_arranged = matmul.arrangement(
  mat1, mat2, output
  )

- return input_arranged, mat1_arrange, mat2_arranged, beta, alpha, output_arranged
+ return input_arranged, mat1_arranged, mat2_arranged, beta, alpha, output_arranged


  def application(input, mat1, mat2, beta, alpha, output):
tests/test_attention.py
@@ -1,3 +1,4 @@
+ import pytest
  import torch
  import torch.nn.functional as F

@@ -34,7 +35,7 @@ def arrangement(q, k, v, o):


  def application(q, k, v, o):
- q_loaded = (q * 1.44269504089).to(ntl.float16)
+ q_loaded = (q * 1.44269504089).to(q.dtype)
  acc = ntl.zeros((q.shape[-2], q.shape[-1]), dtype=ntl.float32)
  l_i = ntl.full((q.shape[-2],), 1, dtype=ntl.float32)

@@ -42,13 +43,14 @@ def application(q, k, v, o):

  for i in range(k.shape[0]):
  qk = ntl.dot(q_loaded, ntl.trans(k[i]))
+ qk = ntl.where(k[i].offsets(-2) < k.source.shape[-2], qk, float("-inf"))

  m_ij = ntl.maximum(m_i, ntl.max(qk, 1))
  p = ntl.exp2(qk - m_ij[:, None])
  l_ij = ntl.sum(p, 1)

  alpha = ntl.exp2(m_i - m_ij)
- acc = acc * alpha[:, None] + ntl.dot(p.to(ntl.float16), v[i])
+ acc = acc * alpha[:, None] + ntl.dot(p.to(v[i].dtype), v[i])
  m_i = m_ij
  l_i = l_i * alpha + l_ij

@@ -83,20 +85,31 @@ def attention(q, k, v):

  @skip_if_cuda_not_available
  class TestCUDA:
+ shapes = ((2, 4, 1024, 64), (2, 4, 1, 64))
+
  @classmethod
  def setup_class(cls):
  torch.manual_seed(0)

- shape = (2, 4, 1024, 64)
+ cls.args = {
+ shape: tuple(torch.randn(shape, device="cuda") for _ in range(3))
+ for shape in cls.shapes
+ }
+
+ @pytest.mark.parametrize("shape", shapes)
+ def test_fp32(self, shape):
+ q, k, v = (arg.to(torch.float32) for arg in type(self).args[shape])

- cls.q = torch.randn(shape, device="cuda")
- cls.k = torch.randn(shape, device="cuda")
- cls.v = torch.randn(shape, device="cuda")
+ assert torch.allclose(
+ attention(q, k, v),
+ F.scaled_dot_product_attention(q, k, v, scale=1),
+ atol=0.01,
+ rtol=0.01,
+ )

- def test_fp16(self):
- q = type(self).q.to(torch.float16)
- k = type(self).k.to(torch.float16)
- v = type(self).v.to(torch.float16)
+ @pytest.mark.parametrize("shape", shapes)
+ def test_fp16(self, shape):
+ q, k, v = (arg.to(torch.float16) for arg in type(self).args[shape])

  assert torch.allclose(
  attention(q, k, v),
tests/test_dropout.py (new file)
@@ -0,0 +1,54 @@
import random

import torch

import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor, block_size
from tests.skippers import skip_if_cuda_not_available


def arrangement(input, p, seed, output, BLOCK_SIZE=block_size()):
    return input.tile((BLOCK_SIZE,)), p, seed, output.tile((BLOCK_SIZE,))


def application(input, p, seed, output):
    output = ntl.where(ntl.rand(seed, input.offsets()) > p, input / (1 - p), 0)  # noqa: F841


def dropout(input, p=0.5):
    seed = random.randrange(0, 2**31)
    output = torch.empty_like(input)

    tensors = (Tensor(1), Tensor(0), Tensor(0), Tensor(1))
    dropout_kernel = ninetoothed.make(arrangement, application, tensors)

    dropout_kernel(input, p, seed, output)

    return output


@skip_if_cuda_not_available
class TestCUDA:
    @classmethod
    def setup_class(cls):
        random.seed(0)
        torch.manual_seed(0)

        size = 349

        cls.input = torch.randn(size, device="cuda")

    def test_fp16(self):
        input = type(self).input.to(torch.float16)
        p = 0.3

        output = dropout(input, p=p)

        assert input.shape == output.shape

        non_zero_ratio = output.nonzero().numel() / input.numel()

        assert abs(non_zero_ratio - (1 - p)) < 0.05

        assert torch.allclose(output[output != 0], input[output != 0] / (1 - p))
tests/test_pow.py (new file)
@@ -0,0 +1,50 @@
import torch

import ninetoothed
from ninetoothed import Tensor, block_size
from ninetoothed.language import libdevice
from tests.skippers import skip_if_cuda_not_available


def arrangement(input, exponent, output, BLOCK_SIZE=block_size()):
    return (
        input.tile((BLOCK_SIZE,)),
        exponent.tile((BLOCK_SIZE,)),
        output.tile((BLOCK_SIZE,)),
    )


def application(input, exponent, output):
    output = libdevice.pow(input, exponent)  # noqa: F841


def pow(input, exponent):
    output = torch.empty_like(input)

    pow_kernel = ninetoothed.make(
        arrangement, application, (Tensor(1), Tensor(1), Tensor(1))
    )

    pow_kernel(input, exponent, output)

    return output


@skip_if_cuda_not_available
class TestCUDA:
    @classmethod
    def setup_class(cls):
        torch.manual_seed(0)

        size = 44925

        cls.input = torch.randn(size, device="cuda")
        cls.exponent = torch.randn(size, device="cuda")

    def test_fp32(self):
        input = type(self).input.to(torch.float32)
        exponent = type(self).exponent.to(torch.float32)

        assert torch.allclose(
            pow(input, exponent), torch.pow(input, exponent), equal_nan=True
        )