ninetoothed 0.11.1__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ninetoothed/tensor.py +4 -2
- ninetoothed/visualization.py +122 -0
- {ninetoothed-0.11.1.dist-info → ninetoothed-0.12.0.dist-info}/METADATA +4 -1
- {ninetoothed-0.11.1.dist-info → ninetoothed-0.12.0.dist-info}/RECORD +6 -5
- {ninetoothed-0.11.1.dist-info → ninetoothed-0.12.0.dist-info}/WHEEL +0 -0
- {ninetoothed-0.11.1.dist-info → ninetoothed-0.12.0.dist-info}/licenses/LICENSE +0 -0
ninetoothed/tensor.py
CHANGED
@@ -3,7 +3,6 @@ import math
|
|
3
3
|
import re
|
4
4
|
|
5
5
|
import ninetoothed.naming as naming
|
6
|
-
from ninetoothed.language import call
|
7
6
|
from ninetoothed.symbol import Symbol
|
8
7
|
|
9
8
|
|
@@ -113,8 +112,11 @@ class Tensor:
|
|
113
112
|
if stride == -1:
|
114
113
|
stride = tile_size
|
115
114
|
|
115
|
+
def cdiv(x, y):
|
116
|
+
return (x + y - 1) // y
|
117
|
+
|
116
118
|
new_size = (
|
117
|
-
|
119
|
+
(cdiv(self_size - spacing * (tile_size - 1) - 1, stride) + 1)
|
118
120
|
if stride != 0
|
119
121
|
else -1
|
120
122
|
)
|
@@ -0,0 +1,122 @@
|
|
1
|
+
import matplotlib.pyplot as plt
|
2
|
+
import numpy as np
|
3
|
+
from mpl_toolkits.axes_grid1 import Divider, Size
|
4
|
+
|
5
|
+
|
6
|
+
def visualize(tensor, color=None, save_path=None):
    """Render ``tensor``'s level hierarchy as a grid of unit squares and save it.

    Each level (``tensor``, then ``tensor.dtype``, and so on while ``dtype``
    is not ``None``) is drawn as a block of squares, and consecutive levels
    are joined by dashed lines.

    :param tensor: The tensor to visualize. Every level must expose ``ndim``,
        a concrete integer ``shape`` and ``dtype``.
    :param color: Fill color for the squares. Defaults to ``"C<n>"`` where
        ``n`` counts previous successful calls, so successive images cycle
        through Matplotlib's default color cycle.
    :param save_path: Destination passed to ``Figure.savefig``.
    """
    from matplotlib.figure import Figure

    outline_width = 0.1

    if color is None:
        color = f"C{visualize.count}"

    # The outline width is in inches; Matplotlib line widths are in points
    # (72 per inch). ``rc_context`` scopes the change to this call instead of
    # permanently altering the global rcParams.
    with plt.rc_context({"lines.linewidth": 72 * outline_width}):
        # Measure the layout on a detached Figure (not tracked by pyplot), so
        # nothing is drawn on the caller's current axes and no figure leaks.
        _, max_pos_x, max_pos_y = _visualize_tensor(
            Figure().add_subplot(), tensor, 0, 0, color
        )

        # Row ("x") positions run down the image and column ("y") positions
        # run across it, so the image width comes from the column extent.
        width = max_pos_y + 1
        height = max_pos_x + 1

        fig = plt.figure(figsize=(width + outline_width, height + outline_width))

        try:
            h = (Size.Fixed(0), Size.Fixed(width + outline_width))
            v = (Size.Fixed(0), Size.Fixed(height + outline_width))

            divider = Divider(fig, (0, 0, 1, 1), h, v, aspect=False)

            ax = fig.add_axes(
                divider.get_position(),
                axes_locator=divider.new_locator(nx=1, ny=1),
            )

            ax.set_aspect("equal")
            ax.axis("off")

            half_outline_width = outline_width / 2
            ax.set_xlim(-half_outline_width, width + half_outline_width)
            # Give the y-limits top-down so row 0 ends up at the top of the
            # image. Calling ``set_ylim(low, high)`` after ``invert_yaxis()``
            # would silently undo the inversion.
            ax.set_ylim(height + half_outline_width, -half_outline_width)

            _visualize_tensor(ax, tensor, 0, 0, color)

            fig.savefig(
                save_path, transparent=True, bbox_inches="tight", pad_inches=0
            )
        finally:
            # Always release the figure, even if drawing or saving fails.
            plt.close(fig)

    visualize.count += 1


# Number of successful ``visualize`` calls; selects the default color.
visualize.count = 0
|
48
|
+
|
49
|
+
|
50
|
+
def _visualize_tensor(ax, tensor, x, y, color, level_spacing=4):
    """Draw ``tensor`` and, recursively, each nested ``dtype`` level.

    Positions are ``(row, column)`` pairs. Each nested level starts on its
    parent's first row, ``level_spacing`` empty columns past the parent's
    last column. It is joined to the parent by two dashed lines.

    :param ax: Axes to draw on.
    :param tensor: Level to draw; recursion stops when ``tensor.dtype`` is
        ``None``.
    :param x: Row of this level's top-left cell.
    :param y: Column of this level's top-left cell.
    :param color: Fill color for every square.
    :param level_spacing: Number of empty columns between consecutive levels
        (applied at every depth).
    :return: ``(verts, max_pos_x, max_pos_y)``. ``verts`` is this level's
        corner-cell tuple from ``_visualize_level``. ``max_pos_x``/``max_pos_y``
        are the largest row/column of any square drawn at this level or below.
    """
    verts = _visualize_level(ax, tensor, x, y, color)
    level_max_x, level_max_y = verts[1][1]

    if tensor.dtype is None:
        return verts, level_max_x, level_max_y

    # The next level starts on this level's first row, one gap to the right.
    next_x, next_y = verts[0][1]
    next_y += level_spacing + 1

    next_verts, sub_max_x, sub_max_y = _visualize_tensor(
        ax, tensor.dtype, next_x, next_y, color, level_spacing
    )

    # Column just to the right of this level, where the funnel starts.
    gap_start_y = next_y - level_spacing
    # Bottom edge of the next level's last row (an absolute row position).
    next_bottom_x = next_verts[1][0][0] + 1

    # Dashed funnel from the cell at (next_x, level_max_y) to the whole left
    # edge of the next level. ``ax.plot`` takes (columns, rows) because rows
    # are drawn along the vertical axis.
    ax.plot((gap_start_y, next_y), (next_x + 1, next_bottom_x), "k--")
    ax.plot((next_y, gap_start_y), (next_x, next_x), "k--")

    return verts, max(level_max_x, sub_max_x), max(level_max_y, sub_max_y)
|
78
|
+
|
79
|
+
|
80
|
+
def _visualize_level(ax, level, x, y, color):
    """Draw one level of a tensor hierarchy as a grid of unit squares.

    Positions are ``(row, column)`` pairs. The last dimension advances along
    columns and the second-to-last along rows. Earlier dimensions alternate
    between columns and rows the same way; each one steps over a full block
    of the dimension two places to its right, plus a one-cell gap.

    :param ax: Axes to draw on.
    :param level: Object exposing ``ndim`` and a concrete integer ``shape``.
    :param x: Row of the level's top-left cell.
    :param y: Column of the level's top-left cell.
    :param color: Fill color for every square.
    :return: ``(((x, y), (x, max_pos_y)), ((max_pos_x, y), (max_pos_x, max_pos_y)))``.
        ``max_pos_x``/``max_pos_y`` are the largest row/column occupied by a
        square, or ``x``/``y`` if nothing is drawn.
    """
    ndim = level.ndim
    shape = level.shape

    # The two innermost dimensions step one cell at a time. Each outer
    # dimension skips a full block of the dimension two places to its right
    # (same drawing axis) plus a one-cell gap.
    offsets = [1] * ndim
    for dim in range(ndim - 3, -1, -1):
        offsets[dim] = offsets[dim + 2] * shape[dim + 2] + 1

    max_pos_x = x
    max_pos_y = y

    # ``np.ndindex`` walks indices in C order and yields a single empty index
    # for a 0-d shape, so a scalar level is drawn as one square instead of
    # failing in ``np.stack`` on an empty list.
    for index in np.ndindex(*shape):
        pos = [x, y]

        for dim, i in enumerate(index):
            # The last dimension maps to columns (pos[1]), the next to rows
            # (pos[0]), and so on alternately.
            pos[(ndim - dim) % 2] += i * offsets[dim]

        max_pos_x = max(max_pos_x, pos[0])
        max_pos_y = max(max_pos_y, pos[1])

        _visualize_unit_square(ax, pos[1], pos[0], color)

    return (((x, y), (x, max_pos_y)), ((max_pos_x, y), (max_pos_x, max_pos_y)))
|
108
|
+
|
109
|
+
|
110
|
+
def _visualize_unit_square(ax, x, y, color):
    """Draw a filled, black-outlined 1 x 1 square anchored at ``(x, y)``."""
    side = 1
    _visualize_rect(ax, side, side, x, y, color)
|
112
|
+
|
113
|
+
|
114
|
+
def _visualize_rect(ax, width, height, x, y, color):
    """Fill the ``width`` x ``height`` rectangle anchored at ``(x, y)``.

    The rectangle is filled with ``color`` and then outlined in black.
    """
    corners = _verts_of_rect(width, height, x, y)
    xs = tuple(corner[0] for corner in corners)
    ys = tuple(corner[1] for corner in corners)

    ax.fill(xs, ys, color)

    # Repeat the first corner so the outline is drawn as a closed loop.
    closed_xs = xs + xs[:1]
    closed_ys = ys + ys[:1]
    ax.plot(closed_xs, closed_ys, "k")
|
119
|
+
|
120
|
+
|
121
|
+
def _verts_of_rect(width, height, x, y):
|
122
|
+
return ((x, y), (x + width, y), (x + width, y + height), (x, y + height))
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ninetoothed
|
3
|
-
Version: 0.11.1
|
3
|
+
Version: 0.12.0
|
4
4
|
Summary: A domain-specific language based on Triton but providing higher-level abstraction.
|
5
5
|
Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
|
6
6
|
Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
|
@@ -11,6 +11,9 @@ Classifier: Operating System :: OS Independent
|
|
11
11
|
Classifier: Programming Language :: Python :: 3
|
12
12
|
Requires-Python: >=3.10
|
13
13
|
Requires-Dist: triton>=3.0.0
|
14
|
+
Provides-Extra: visualization
|
15
|
+
Requires-Dist: matplotlib>=3.9.0; extra == 'visualization'
|
16
|
+
Requires-Dist: numpy>=2.1.0; extra == 'visualization'
|
14
17
|
Description-Content-Type: text/markdown
|
15
18
|
|
16
19
|
# NineToothed
|
@@ -3,9 +3,10 @@ ninetoothed/jit.py,sha256=U3Nen5vyx69ulW7_hnRuATW86Ag9NgVgd3U02NVB20c,24430
|
|
3
3
|
ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
|
4
4
|
ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
|
5
5
|
ninetoothed/symbol.py,sha256=mN96tp-2eUxbiNfxuxtKWNSxOSdYqlcmpY2MYQ-FiEg,4993
|
6
|
-
ninetoothed/tensor.py,sha256=
|
6
|
+
ninetoothed/tensor.py,sha256=E63sq3jh7ZLiLwFYTtavztKEZx7kRX-UVa2ZXSP2X0s,12008
|
7
7
|
ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
|
8
|
-
ninetoothed
|
9
|
-
ninetoothed-0.
|
10
|
-
ninetoothed-0.
|
11
|
-
ninetoothed-0.
|
8
|
+
ninetoothed/visualization.py,sha256=VPPh__Bral_Z9hKj9D4UOo8HvRFQidCWSe9cS-D5QfY,3351
|
9
|
+
ninetoothed-0.12.0.dist-info/METADATA,sha256=gb5zeAxwYQRm-dFNmdX3uDIn_U-abEZ-VU9bSBBxioA,7198
|
10
|
+
ninetoothed-0.12.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
11
|
+
ninetoothed-0.12.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
12
|
+
ninetoothed-0.12.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|