ninetoothed 0.11.0__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ninetoothed/jit.py +31 -3
- ninetoothed/symbol.py +7 -0
- ninetoothed/tensor.py +67 -2
- ninetoothed/visualization.py +122 -0
- {ninetoothed-0.11.0.dist-info → ninetoothed-0.12.0.dist-info}/METADATA +4 -1
- ninetoothed-0.12.0.dist-info/RECORD +12 -0
- ninetoothed-0.11.0.dist-info/RECORD +0 -11
- {ninetoothed-0.11.0.dist-info → ninetoothed-0.12.0.dist-info}/WHEEL +0 -0
- {ninetoothed-0.11.0.dist-info → ninetoothed-0.12.0.dist-info}/licenses/LICENSE +0 -0
ninetoothed/jit.py
CHANGED
@@ -20,6 +20,13 @@ from ninetoothed.torchifier import Torchifier
|
|
20
20
|
|
21
21
|
|
22
22
|
def make(arrangement, application, tensors):
|
23
|
+
"""Integrate the arrangement and the application of the tensors.
|
24
|
+
|
25
|
+
:param arrangement: The arrangement of the tensors.
|
26
|
+
:param application: The application of the tensors.
|
27
|
+
:param tensors: The tensors.
|
28
|
+
:return: A handle to the compute kernel.
|
29
|
+
"""
|
23
30
|
params = inspect.signature(application).parameters
|
24
31
|
types = arrangement(*tensors)
|
25
32
|
annotations = {param: type for param, type in zip(params, types)}
|
@@ -28,14 +35,26 @@ def make(arrangement, application, tensors):
|
|
28
35
|
return jit(application)
|
29
36
|
|
30
37
|
|
31
|
-
def jit(
|
38
|
+
def jit(func=None, *, _prettify=False):
|
39
|
+
"""A decorator for generating compute kernels.
|
40
|
+
|
41
|
+
:param func: The function to be compiled.
|
42
|
+
:param _prettify: Whether to prettify the generated code.
|
43
|
+
:return: A handle to the compute kernel.
|
44
|
+
|
45
|
+
.. note::
|
46
|
+
|
47
|
+
The ``_prettify`` parameter is experimental, which might break
|
48
|
+
the generated code.
|
49
|
+
"""
|
50
|
+
|
32
51
|
def wrapper(func):
|
33
52
|
return JIT(func, _prettify=_prettify)()
|
34
53
|
|
35
|
-
if
|
54
|
+
if func is None:
|
36
55
|
return wrapper
|
37
56
|
|
38
|
-
return wrapper(
|
57
|
+
return wrapper(func)
|
39
58
|
|
40
59
|
|
41
60
|
class JIT:
|
@@ -472,6 +491,15 @@ class CodeGenerator(ast.NodeTransformer):
|
|
472
491
|
for target_dim in range(tensor.target.ndim)
|
473
492
|
if offsets[source_dim][target_dim] != 0
|
474
493
|
),
|
494
|
+
) & functools.reduce(
|
495
|
+
lambda x, y: x & y,
|
496
|
+
(
|
497
|
+
indices[dim - tensor.innermost().target.ndim][
|
498
|
+
type(self)._generate_slices(tensor, target_dim)
|
499
|
+
]
|
500
|
+
< tensor.innermost().target.shape[dim]
|
501
|
+
for dim, target_dim in enumerate(tensor.innermost().target_dims)
|
502
|
+
),
|
475
503
|
)
|
476
504
|
|
477
505
|
return pointers, mask
|
ninetoothed/symbol.py
CHANGED
@@ -7,6 +7,13 @@ import ninetoothed.naming as naming
|
|
7
7
|
|
8
8
|
|
9
9
|
class Symbol:
|
10
|
+
"""A class used to represent a symbol.
|
11
|
+
|
12
|
+
:param expr: The expression used to construct the symbol.
|
13
|
+
:param constexpr: Whether the symbol is a constexpr.
|
14
|
+
:param meta: Whether the symbol is a meta.
|
15
|
+
"""
|
16
|
+
|
10
17
|
def __init__(self, expr, constexpr=None, meta=None):
|
11
18
|
if isinstance(expr, type(self)):
|
12
19
|
self._node = expr._node
|
ninetoothed/tensor.py
CHANGED
@@ -3,11 +3,25 @@ import math
|
|
3
3
|
import re
|
4
4
|
|
5
5
|
import ninetoothed.naming as naming
|
6
|
-
from ninetoothed.language import call
|
7
6
|
from ninetoothed.symbol import Symbol
|
8
7
|
|
9
8
|
|
10
9
|
class Tensor:
|
10
|
+
"""A class used to represent a symbolic tensor.
|
11
|
+
|
12
|
+
:param ndim: The number of dimensions of the tensor.
|
13
|
+
:param shape: The shape of the tensor.
|
14
|
+
:param dtype: The element type of the tensor.
|
15
|
+
:param strides: The strides of the tensor.
|
16
|
+
:param other: The values for out-of-bounds positions.
|
17
|
+
:param constexpr_shape: Whether the sizes are constexpr.
|
18
|
+
:param name: The name of the tensor.
|
19
|
+
:param source: For internal use only.
|
20
|
+
:param source_dims: For internal use only.
|
21
|
+
:param target: For internal use only.
|
22
|
+
:param target_dims: For internal use only.
|
23
|
+
"""
|
24
|
+
|
11
25
|
num_instances = 0
|
12
26
|
|
13
27
|
def __init__(
|
@@ -70,6 +84,14 @@ class Tensor:
|
|
70
84
|
type(self).num_instances += 1
|
71
85
|
|
72
86
|
def tile(self, tile_shape, strides=None, dilation=None):
|
87
|
+
"""Tiles the tensor into a hierarchical tensor.
|
88
|
+
|
89
|
+
:param tile_shape: The shape of a tile.
|
90
|
+
:param strides: The interval at which each tile is generated.
|
91
|
+
:param dilation: The spacing between tiles.
|
92
|
+
:return: A hierarchical tensor.
|
93
|
+
"""
|
94
|
+
|
73
95
|
if strides is None:
|
74
96
|
strides = [-1 for _ in tile_shape]
|
75
97
|
|
@@ -90,8 +112,11 @@ class Tensor:
|
|
90
112
|
if stride == -1:
|
91
113
|
stride = tile_size
|
92
114
|
|
115
|
+
def cdiv(x, y):
|
116
|
+
return (x + y - 1) // y
|
117
|
+
|
93
118
|
new_size = (
|
94
|
-
|
119
|
+
(cdiv(self_size - spacing * (tile_size - 1) - 1, stride) + 1)
|
95
120
|
if stride != 0
|
96
121
|
else -1
|
97
122
|
)
|
@@ -119,6 +144,12 @@ class Tensor:
|
|
119
144
|
)
|
120
145
|
|
121
146
|
def expand(self, shape):
|
147
|
+
"""Expands the specified singleton dimensions of the tensor.
|
148
|
+
|
149
|
+
:param shape: The expanded shape.
|
150
|
+
:return: The expanded tensor.
|
151
|
+
"""
|
152
|
+
|
122
153
|
# TODO: Add error handling.
|
123
154
|
return type(self)(
|
124
155
|
shape=[
|
@@ -136,6 +167,12 @@ class Tensor:
|
|
136
167
|
)
|
137
168
|
|
138
169
|
def squeeze(self, dim):
|
170
|
+
"""Removes the specified singleton dimensions of the tensor.
|
171
|
+
|
172
|
+
:param dim: The dimension(s) to be squeezed.
|
173
|
+
:return: The squeezed tensor.
|
174
|
+
"""
|
175
|
+
|
139
176
|
if not isinstance(dim, tuple):
|
140
177
|
dim = (dim,)
|
141
178
|
|
@@ -158,6 +195,12 @@ class Tensor:
|
|
158
195
|
)
|
159
196
|
|
160
197
|
def permute(self, dims):
|
198
|
+
"""Permutes the dimensions of the tensor.
|
199
|
+
|
200
|
+
:param dims: The permuted ordering of the dimensions.
|
201
|
+
:return: The permuted tensor.
|
202
|
+
"""
|
203
|
+
|
161
204
|
# TODO: Add error handling.
|
162
205
|
new_shape = [None for _ in range(self.ndim)]
|
163
206
|
new_strides = [None for _ in range(self.ndim)]
|
@@ -178,6 +221,16 @@ class Tensor:
|
|
178
221
|
)
|
179
222
|
|
180
223
|
def flatten(self, start_dim=None, end_dim=None):
|
224
|
+
"""Flattens the specified dimensions of the tensor.
|
225
|
+
|
226
|
+
See :func:`ravel` for the differences between :func:`flatten`
|
227
|
+
and :func:`ravel`.
|
228
|
+
|
229
|
+
:param start_dim: The first dimension to flatten.
|
230
|
+
:param end_dim: The dimension after the last to flatten.
|
231
|
+
:return: The flattened tensor.
|
232
|
+
"""
|
233
|
+
|
181
234
|
# TODO: Add error handling.
|
182
235
|
if start_dim is None:
|
183
236
|
start_dim = 0
|
@@ -222,6 +275,18 @@ class Tensor:
|
|
222
275
|
)
|
223
276
|
|
224
277
|
def ravel(self):
|
278
|
+
"""Flattens the hierarchy of the tensor.
|
279
|
+
|
280
|
+
:func:`ravel` differs from :func:`flatten`, which only flattens
|
281
|
+
dimensions at a single level. For example, consider a tensor
|
282
|
+
with two levels: the first level has a shape of ``(N, P, Q)``,
|
283
|
+
and the second level has a shape of ``(C, R, S)``. After
|
284
|
+
applying :func:`ravel`, the resulting tensor will have a single
|
285
|
+
flattened level with a shape of ``(N, P, Q, C, R, S)``.
|
286
|
+
|
287
|
+
:return: The raveled tensor.
|
288
|
+
"""
|
289
|
+
|
225
290
|
# TODO: Add error handling.
|
226
291
|
new_shape = []
|
227
292
|
new_strides = []
|
@@ -0,0 +1,122 @@
|
|
1
|
+
import matplotlib.pyplot as plt
|
2
|
+
import numpy as np
|
3
|
+
from mpl_toolkits.axes_grid1 import Divider, Size
|
4
|
+
|
5
|
+
|
6
|
+
def visualize(tensor, color=None, save_path=None):
|
7
|
+
outline_width = 0.1
|
8
|
+
plt.rcParams["lines.linewidth"] = 72 * outline_width
|
9
|
+
|
10
|
+
if color is None:
|
11
|
+
color = f"C{visualize.count}"
|
12
|
+
|
13
|
+
_, max_pos_x, max_pos_y = _visualize_tensor(plt.gca(), tensor, 0, 0, color)
|
14
|
+
|
15
|
+
width = max_pos_y + 1
|
16
|
+
height = max_pos_x + 1
|
17
|
+
|
18
|
+
fig = plt.figure(figsize=(width + outline_width, height + outline_width))
|
19
|
+
|
20
|
+
h = (Size.Fixed(0), Size.Fixed(width + outline_width))
|
21
|
+
v = (Size.Fixed(0), Size.Fixed(height + outline_width))
|
22
|
+
|
23
|
+
divider = Divider(fig, (0, 0, 1, 1), h, v, aspect=False)
|
24
|
+
|
25
|
+
ax = fig.add_axes(
|
26
|
+
divider.get_position(), axes_locator=divider.new_locator(nx=1, ny=1)
|
27
|
+
)
|
28
|
+
|
29
|
+
ax.set_aspect("equal")
|
30
|
+
ax.invert_yaxis()
|
31
|
+
|
32
|
+
plt.axis("off")
|
33
|
+
|
34
|
+
half_outline_width = outline_width / 2
|
35
|
+
plt.xlim((-half_outline_width, width + half_outline_width))
|
36
|
+
plt.ylim((-half_outline_width, height + half_outline_width))
|
37
|
+
|
38
|
+
_visualize_tensor(ax, tensor, 0, 0, color)
|
39
|
+
|
40
|
+
plt.savefig(save_path, transparent=True, bbox_inches="tight", pad_inches=0)
|
41
|
+
|
42
|
+
plt.close()
|
43
|
+
|
44
|
+
visualize.count += 1
|
45
|
+
|
46
|
+
|
47
|
+
visualize.count = 0
|
48
|
+
|
49
|
+
|
50
|
+
def _visualize_tensor(ax, tensor, x, y, color, level_spacing=4):
|
51
|
+
verts = _visualize_level(ax, tensor, x, y, color)
|
52
|
+
|
53
|
+
if tensor.dtype is None:
|
54
|
+
return verts, verts[1][1][0], verts[1][1][1]
|
55
|
+
|
56
|
+
next_x, next_y = verts[0][1]
|
57
|
+
next_y += level_spacing + 1
|
58
|
+
|
59
|
+
next_verts, max_pos_x, max_pos_y = _visualize_tensor(
|
60
|
+
ax, tensor.dtype, next_x, next_y, color
|
61
|
+
)
|
62
|
+
|
63
|
+
conn_verts = _verts_of_rect(1, level_spacing, next_x, next_y - level_spacing)
|
64
|
+
conn_verts = [list(vert) for vert in conn_verts]
|
65
|
+
conn_verts[2][0] += next_verts[1][0][0]
|
66
|
+
|
67
|
+
pos_y, pos_x = zip(*conn_verts)
|
68
|
+
pos_x = pos_x + (pos_x[0],)
|
69
|
+
pos_y = pos_y + (pos_y[0],)
|
70
|
+
|
71
|
+
ax.plot(pos_x[1:3], pos_y[1:3], "k--")
|
72
|
+
ax.plot(pos_x[3:5], pos_y[3:5], "k--")
|
73
|
+
|
74
|
+
max_pos_x = max(max_pos_x, verts[1][1][0])
|
75
|
+
max_pos_y = max(max_pos_y, verts[1][1][1])
|
76
|
+
|
77
|
+
return verts, max_pos_x, max_pos_y
|
78
|
+
|
79
|
+
|
80
|
+
def _visualize_level(ax, level, x, y, color):
|
81
|
+
offsets = [1 for _ in range(level.ndim)]
|
82
|
+
|
83
|
+
for dim in range(-3, -level.ndim - 1, -1):
|
84
|
+
offsets[dim] = offsets[dim + 2] * level.shape[dim + 2] + 1
|
85
|
+
|
86
|
+
indices = np.indices(level.shape)
|
87
|
+
flattened_indices = np.stack(
|
88
|
+
[indices[i].flatten() for i in range(level.ndim)], axis=-1
|
89
|
+
)
|
90
|
+
|
91
|
+
max_pos_x = x
|
92
|
+
max_pos_y = y
|
93
|
+
|
94
|
+
for indices in flattened_indices:
|
95
|
+
pos = [x, y]
|
96
|
+
|
97
|
+
for dim, index in enumerate(indices):
|
98
|
+
pos[(level.ndim - dim) % 2] += index * offsets[dim]
|
99
|
+
|
100
|
+
max_pos_x = max(max_pos_x, pos[0])
|
101
|
+
max_pos_y = max(max_pos_y, pos[1])
|
102
|
+
|
103
|
+
_visualize_unit_square(ax, pos[1], pos[0], color)
|
104
|
+
|
105
|
+
verts = (((x, y), (x, max_pos_y)), ((max_pos_x, y), (max_pos_x, max_pos_y)))
|
106
|
+
|
107
|
+
return verts
|
108
|
+
|
109
|
+
|
110
|
+
def _visualize_unit_square(ax, x, y, color):
|
111
|
+
_visualize_rect(ax, 1, 1, x, y, color)
|
112
|
+
|
113
|
+
|
114
|
+
def _visualize_rect(ax, width, height, x, y, color):
|
115
|
+
pos_x, pos_y = zip(*_verts_of_rect(width, height, x, y))
|
116
|
+
|
117
|
+
ax.fill(pos_x, pos_y, color)
|
118
|
+
ax.plot(pos_x + (pos_x[0],), pos_y + (pos_y[0],), "k")
|
119
|
+
|
120
|
+
|
121
|
+
def _verts_of_rect(width, height, x, y):
|
122
|
+
return ((x, y), (x + width, y), (x + width, y + height), (x, y + height))
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ninetoothed
|
3
|
-
Version: 0.11.0
|
3
|
+
Version: 0.12.0
|
4
4
|
Summary: A domain-specific language based on Triton but providing higher-level abstraction.
|
5
5
|
Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
|
6
6
|
Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
|
@@ -11,6 +11,9 @@ Classifier: Operating System :: OS Independent
|
|
11
11
|
Classifier: Programming Language :: Python :: 3
|
12
12
|
Requires-Python: >=3.10
|
13
13
|
Requires-Dist: triton>=3.0.0
|
14
|
+
Provides-Extra: visualization
|
15
|
+
Requires-Dist: matplotlib>=3.9.0; extra == 'visualization'
|
16
|
+
Requires-Dist: numpy>=2.1.0; extra == 'visualization'
|
14
17
|
Description-Content-Type: text/markdown
|
15
18
|
|
16
19
|
# NineToothed
|
@@ -0,0 +1,12 @@
|
|
1
|
+
ninetoothed/__init__.py,sha256=dX34sk5GA3OgWf1Jc4gJMW3UwcGcJsuG3hs3rkiqq6g,161
|
2
|
+
ninetoothed/jit.py,sha256=U3Nen5vyx69ulW7_hnRuATW86Ag9NgVgd3U02NVB20c,24430
|
3
|
+
ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
|
4
|
+
ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
|
5
|
+
ninetoothed/symbol.py,sha256=mN96tp-2eUxbiNfxuxtKWNSxOSdYqlcmpY2MYQ-FiEg,4993
|
6
|
+
ninetoothed/tensor.py,sha256=E63sq3jh7ZLiLwFYTtavztKEZx7kRX-UVa2ZXSP2X0s,12008
|
7
|
+
ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
|
8
|
+
ninetoothed/visualization.py,sha256=VPPh__Bral_Z9hKj9D4UOo8HvRFQidCWSe9cS-D5QfY,3351
|
9
|
+
ninetoothed-0.12.0.dist-info/METADATA,sha256=gb5zeAxwYQRm-dFNmdX3uDIn_U-abEZ-VU9bSBBxioA,7198
|
10
|
+
ninetoothed-0.12.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
11
|
+
ninetoothed-0.12.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
12
|
+
ninetoothed-0.12.0.dist-info/RECORD,,
|
@@ -1,11 +0,0 @@
|
|
1
|
-
ninetoothed/__init__.py,sha256=dX34sk5GA3OgWf1Jc4gJMW3UwcGcJsuG3hs3rkiqq6g,161
|
2
|
-
ninetoothed/jit.py,sha256=0LeDBpSYFgPx4hatP_ZsvElsj0d9d552OKRc__L1Jvc,23460
|
3
|
-
ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
|
4
|
-
ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
|
5
|
-
ninetoothed/symbol.py,sha256=rZ5nXtn-U1Nw0BBRJ-kfrwmX_zCbAi76un-Z2QFaoZc,4773
|
6
|
-
ninetoothed/tensor.py,sha256=ehgbHrxL1gc9iOI9rGUp5k-nI_daLmCfOdLI9hE-GLw,9756
|
7
|
-
ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
|
8
|
-
ninetoothed-0.11.0.dist-info/METADATA,sha256=NlC4oj1R7gNoCdA1AxYc_7UWY3uZPv3LeOq9Jo-9Q8w,7055
|
9
|
-
ninetoothed-0.11.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
10
|
-
ninetoothed-0.11.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
11
|
-
ninetoothed-0.11.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|