ninetoothed 0.11.1__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ninetoothed/tensor.py CHANGED
@@ -3,7 +3,6 @@ import math
3
3
  import re
4
4
 
5
5
  import ninetoothed.naming as naming
6
- from ninetoothed.language import call
7
6
  from ninetoothed.symbol import Symbol
8
7
 
9
8
 
@@ -113,8 +112,11 @@ class Tensor:
113
112
  if stride == -1:
114
113
  stride = tile_size
115
114
 
115
+ def cdiv(x, y):
116
+ return (x + y - 1) // y
117
+
116
118
  new_size = (
117
- call("cdiv", self_size - spacing * (tile_size - 1) - 1, stride) + 1
119
+ (cdiv(self_size - spacing * (tile_size - 1) - 1, stride) + 1)
118
120
  if stride != 0
119
121
  else -1
120
122
  )
@@ -0,0 +1,122 @@
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ from mpl_toolkits.axes_grid1 import Divider, Size
4
+
5
+
6
def visualize(tensor, color=None, save_path=None):
    """Draw the hierarchical layout of ``tensor`` and save (or show) it.

    Each level of ``tensor`` is drawn as a grid of unit squares, left to
    right. The levels are ``tensor``, then ``tensor.dtype``, and so on, down
    to a level whose ``dtype`` is ``None``. Consecutive levels are joined by
    dashed connectors.

    :param tensor: The tensor to visualize. Its shapes are fed to ``numpy``
        index helpers, so they presumably must be concrete integers.
    :param color: Fill color of the squares. Defaults to the matplotlib
        color-cycle entry ``"C{n}"``, where ``n`` is the number of previous
        successful calls.
    :param save_path: File to write the image to. If ``None``, the figure
        is displayed with ``plt.show()`` instead of being saved.
    """
    outline_width = 0.1
    half_outline_width = outline_width / 2

    if color is None:
        color = f"C{visualize.count}"

    # Scope the outline width to this call. The previous code overwrote the
    # global ``plt.rcParams`` for the rest of the process.
    with plt.rc_context({"lines.linewidth": 72 * outline_width}):
        # Measuring pass. The extent is only known after drawing, so draw onto
        # a throwaway figure that is always closed. Using ``plt.gca()`` here
        # would draw on the caller's current axes, or create a figure that is
        # never closed.
        scratch_fig = plt.figure()
        try:
            _, max_pos_x, max_pos_y = _visualize_tensor(
                scratch_fig.gca(), tensor, 0, 0, color
            )
        finally:
            plt.close(scratch_fig)

        # The drawing helpers use internal "x" for the plot's vertical
        # direction and internal "y" for the horizontal one.
        width = max_pos_y + 1
        height = max_pos_x + 1

        fig = plt.figure(figsize=(width + outline_width, height + outline_width))

        try:
            h = (Size.Fixed(0), Size.Fixed(width + outline_width))
            v = (Size.Fixed(0), Size.Fixed(height + outline_width))

            divider = Divider(fig, (0, 0, 1, 1), h, v, aspect=False)

            ax = fig.add_axes(
                divider.get_position(), axes_locator=divider.new_locator(nx=1, ny=1)
            )

            ax.set_aspect("equal")
            ax.axis("off")

            ax.set_xlim(-half_outline_width, width + half_outline_width)
            # Give the y-limits top-first so the y-axis is inverted (row 0 at
            # the top). Calling ``invert_yaxis()`` followed by an ascending
            # ``ylim`` would silently undo the inversion.
            ax.set_ylim(height + half_outline_width, -half_outline_width)

            _visualize_tensor(ax, tensor, 0, 0, color)

            if save_path is None:
                plt.show()
            else:
                fig.savefig(
                    save_path, transparent=True, bbox_inches="tight", pad_inches=0
                )
        finally:
            # Release the figure even if drawing or saving fails.
            plt.close(fig)

    visualize.count += 1


# Number of successful ``visualize`` calls; picks the next default "C{n}" color.
visualize.count = 0
48
+
49
+
50
def _visualize_tensor(ax, tensor, x, y, color, level_spacing=4):
    """Draw ``tensor`` and, recursively, each nested ``dtype`` level to its right.

    Coordinates are internal ones. ``x`` is the plot's vertical coordinate
    and ``y`` is the horizontal one, because squares are drawn at
    ``(pos[1], pos[0])``.

    :param ax: Matplotlib axes to draw on.
    :param tensor: The current level. Recursion stops when its ``dtype`` is
        ``None``.
    :param x: Internal vertical origin of this level.
    :param y: Internal horizontal origin of this level.
    :param color: Fill color of the squares.
    :param level_spacing: Gap, in squares, between consecutive levels. It now
        applies to every nesting depth, not only the outermost one.
    :return: ``(verts, max_pos_x, max_pos_y)``. ``verts`` is this level's
        bounding vertices from ``_visualize_level``. The two maxima are the
        largest square origins drawn across this level and all nested levels.
    """
    verts = _visualize_level(ax, tensor, x, y, color)
    level_max_pos_x, level_max_pos_y = verts[1][1]

    if tensor.dtype is None:
        return verts, level_max_pos_x, level_max_pos_y

    # Place the next level one square plus ``level_spacing`` to the right of
    # this level's last column.
    next_x, next_y = verts[0][1]
    next_y += level_spacing + 1

    next_verts, max_pos_x, max_pos_y = _visualize_tensor(
        ax, tensor.dtype, next_x, next_y, color, level_spacing=level_spacing
    )

    conn_verts = [
        list(vert)
        for vert in _verts_of_rect(1, level_spacing, next_x, next_y - level_spacing)
    ]
    # Stretch the far corner down to the bottom edge of the next level. The
    # old "+=" form was only correct when this level started at x == 0.
    conn_verts[2][0] = next_verts[1][0][0] + 1

    # Vertices are (internal x, internal y) = (vertical, horizontal). Swap
    # them for matplotlib, then close the loop so the two dashed edges can be
    # sliced out.
    pos_y, pos_x = zip(*conn_verts)
    pos_x += (pos_x[0],)
    pos_y += (pos_y[0],)

    ax.plot(pos_x[1:3], pos_y[1:3], "k--")
    ax.plot(pos_x[3:5], pos_y[3:5], "k--")

    return (
        verts,
        max(max_pos_x, level_max_pos_x),
        max(max_pos_y, level_max_pos_y),
    )
78
+
79
+
80
def _visualize_level(ax, level, x, y, color):
    """Draw one level of a tensor as unit squares and return its bounding vertices.

    Internal ``x`` is the plot's vertical coordinate and ``y`` the horizontal
    one, because each square is drawn at ``(pos[1], pos[0])``. Dimensions
    alternate between the two directions, and the last dimension is laid out
    horizontally. Each dimension further out than the last two leaves a
    one-square gap between its sub-blocks.

    :param ax: Matplotlib axes to draw on.
    :param level: Object exposing ``ndim`` and ``shape``. The shape is fed to
        ``np.ndindex``, so it presumably must hold concrete integers.
    :param x: Internal vertical origin.
    :param y: Internal horizontal origin.
    :param color: Fill color of the squares.
    :return: ``(((x, y), (x, max_y)), ((max_x, y), (max_x, max_y)))``, where
        ``max_x``/``max_y`` are the largest square origins drawn (or ``x``/``y``
        when nothing is drawn).
    """
    ndim = level.ndim
    shape = level.shape

    # ``offsets[dim]`` is the distance one step along ``dim`` moves a square.
    # The last two dimensions step one square at a time. Each outer dimension
    # jumps over a whole block of the dimension two places inward (which
    # shares its direction) plus a one-square gap.
    offsets = [1] * ndim
    for dim in range(-3, -ndim - 1, -1):
        offsets[dim] = offsets[dim + 2] * shape[dim + 2] + 1

    max_pos_x = x
    max_pos_y = y

    # ``np.ndindex`` walks indices in C order, the same order as the
    # flattened ``np.indices`` used previously. Unlike ``np.stack`` over an
    # empty list, it also handles 0-d levels by yielding a single empty
    # index, so one square is drawn.
    for index in np.ndindex(tuple(shape)):
        pos = [x, y]

        for dim, step in enumerate(index):
            pos[(ndim - dim) % 2] += step * offsets[dim]

        max_pos_x = max(max_pos_x, pos[0])
        max_pos_y = max(max_pos_y, pos[1])

        _visualize_unit_square(ax, pos[1], pos[0], color)

    return (((x, y), (x, max_pos_y)), ((max_pos_x, y), (max_pos_x, max_pos_y)))
108
+
109
+
110
def _visualize_unit_square(ax, x, y, color):
    """Draw a filled, black-outlined 1x1 square spanning [x, x+1] x [y, y+1] in plot coordinates."""
    side = 1
    return _visualize_rect(ax, side, side, x, y, color)
112
+
113
+
114
def _visualize_rect(ax, width, height, x, y, color):
    """Fill the rectangle [x, x+width] x [y, y+height] with ``color`` and outline it in black."""
    corners = _verts_of_rect(width, height, x, y)
    xs = tuple(corner_x for corner_x, _ in corners)
    ys = tuple(corner_y for _, corner_y in corners)

    ax.fill(xs, ys, color)

    # Repeat the first corner so the outline closes back on itself.
    ax.plot(xs + xs[:1], ys + ys[:1], "k")
119
+
120
+
121
def _verts_of_rect(width, height, x, y):
    """Return the four corners of a rectangle as a tuple of ``(x, y)`` pairs.

    The order is ``(x, y)``, ``(x + width, y)``, ``(x + width, y + height)``,
    ``(x, y + height)``.
    """
    far_x = x + width
    far_y = y + height
    return ((x, y), (far_x, y), (far_x, far_y), (x, far_y))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ninetoothed
3
- Version: 0.11.1
3
+ Version: 0.12.0
4
4
  Summary: A domain-specific language based on Triton but providing higher-level abstraction.
5
5
  Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
6
6
  Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
@@ -11,6 +11,9 @@ Classifier: Operating System :: OS Independent
11
11
  Classifier: Programming Language :: Python :: 3
12
12
  Requires-Python: >=3.10
13
13
  Requires-Dist: triton>=3.0.0
14
+ Provides-Extra: visualization
15
+ Requires-Dist: matplotlib>=3.9.0; extra == 'visualization'
16
+ Requires-Dist: numpy>=2.1.0; extra == 'visualization'
14
17
  Description-Content-Type: text/markdown
15
18
 
16
19
  # NineToothed
@@ -3,9 +3,10 @@ ninetoothed/jit.py,sha256=U3Nen5vyx69ulW7_hnRuATW86Ag9NgVgd3U02NVB20c,24430
3
3
  ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
4
4
  ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
5
5
  ninetoothed/symbol.py,sha256=mN96tp-2eUxbiNfxuxtKWNSxOSdYqlcmpY2MYQ-FiEg,4993
6
- ninetoothed/tensor.py,sha256=Pgl3t08qNDJrjbNMLltEIwMu19vnKwMUncOq4aedTjY,11983
6
+ ninetoothed/tensor.py,sha256=E63sq3jh7ZLiLwFYTtavztKEZx7kRX-UVa2ZXSP2X0s,12008
7
7
  ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
8
- ninetoothed-0.11.1.dist-info/METADATA,sha256=DsvP5HloDDgEiHUtPEu8blUVYCBr3_VXYILUikfOE-Y,7055
9
- ninetoothed-0.11.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
10
- ninetoothed-0.11.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
11
- ninetoothed-0.11.1.dist-info/RECORD,,
8
+ ninetoothed/visualization.py,sha256=VPPh__Bral_Z9hKj9D4UOo8HvRFQidCWSe9cS-D5QfY,3351
9
+ ninetoothed-0.12.0.dist-info/METADATA,sha256=gb5zeAxwYQRm-dFNmdX3uDIn_U-abEZ-VU9bSBBxioA,7198
10
+ ninetoothed-0.12.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
11
+ ninetoothed-0.12.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
12
+ ninetoothed-0.12.0.dist-info/RECORD,,