ninetoothed 0.11.1__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ninetoothed/jit.py CHANGED
@@ -483,13 +483,15 @@ class CodeGenerator(ast.NodeTransformer):
483
483
  mask = functools.reduce(
484
484
  lambda x, y: x & y,
485
485
  (
486
- offsets[source_dim][target_dim][
487
- type(self)._generate_slices(tensor, target_dim)
488
- ]
486
+ sum(
487
+ offsets[source_dim][target_dim][
488
+ type(self)._generate_slices(tensor, target_dim)
489
+ ]
490
+ for target_dim in range(tensor.target.ndim)
491
+ if offsets[source_dim][target_dim] != 0
492
+ )
489
493
  < tensor.source.shape[source_dim]
490
494
  for source_dim in range(tensor.source.ndim)
491
- for target_dim in range(tensor.target.ndim)
492
- if offsets[source_dim][target_dim] != 0
493
495
  ),
494
496
  ) & functools.reduce(
495
497
  lambda x, y: x & y,
@@ -505,13 +507,21 @@ class CodeGenerator(ast.NodeTransformer):
505
507
  return pointers, mask
506
508
 
507
509
  def _complete_indices(self, tensor, indices):
510
+ class _NextPowerOfTwoMaker(ast.NodeTransformer):
511
+ def visit_Name(self, node):
512
+ name = node.id
513
+
514
+ if not naming.is_meta(name):
515
+ next_power_of_2_name = naming.make_next_power_of_2(name)
516
+
517
+ return ast.Name(id=next_power_of_2_name, ctx=ast.Load())
518
+
519
+ return self.generic_visit(node)
520
+
508
521
  indices = list(self._generate_pid_indices(tensor) + tuple(indices))
509
522
 
510
523
  for size in tensor.innermost().shape:
511
- if Symbol.is_name(size):
512
- name = size.node.id
513
- if not naming.is_meta(name):
514
- size = naming.make_next_power_of_2(name)
524
+ size = _NextPowerOfTwoMaker().visit(Symbol(copy.deepcopy(size)).node)
515
525
 
516
526
  indices.append(call("arange", 0, size))
517
527
 
@@ -549,8 +559,10 @@ class CodeGenerator(ast.NodeTransformer):
549
559
 
550
560
  @staticmethod
551
561
  def _generate_offsets(tensor, indices):
552
- offsets = collections.defaultdict(
553
- lambda: collections.defaultdict(lambda: Symbol(0))
562
+ raw_offsets = collections.defaultdict(
563
+ lambda: collections.defaultdict(
564
+ lambda: collections.defaultdict(lambda: Symbol(0))
565
+ )
554
566
  )
555
567
 
556
568
  curr = tensor
@@ -560,36 +572,62 @@ class CodeGenerator(ast.NodeTransformer):
560
572
  stop = start + curr.ndim
561
573
  curr_indices = indices[start:stop]
562
574
 
563
- for index, stride, source_dim, target_dim in zip(
564
- curr_indices, curr.strides, curr.source_dims, curr.target_dims
575
+ for index, stride, source_dim, target_dim, unflattened_dim in zip(
576
+ curr_indices,
577
+ curr.strides,
578
+ curr.source_dims,
579
+ curr.target_dims,
580
+ curr.unflattened_dims,
565
581
  ):
566
- offsets[source_dim][target_dim] += index * stride
582
+ raw_offsets[source_dim][target_dim][unflattened_dim] += index * stride
567
583
 
568
584
  start = stop
569
585
  curr = curr.dtype
570
586
 
571
- for source_dim in tuple(offsets):
572
- for target_dim in tuple(offsets[source_dim]):
573
- if not isinstance(source_dim, tuple):
574
- continue
587
+ offsets = collections.defaultdict(
588
+ lambda: collections.defaultdict(lambda: Symbol(0))
589
+ )
575
590
 
576
- unraveled = CodeGenerator._unravel_index(
577
- offsets[source_dim][target_dim],
578
- tuple(tensor.source.shape[dim] for dim in source_dim),
579
- )
591
+ source_strides = tuple(Symbol(stride) for stride in tensor.source.strides)
580
592
 
581
- for offs, dim in zip(unraveled, source_dim):
582
- offsets[dim][target_dim] = offs
593
+ unflattened_strides = tuple(
594
+ Symbol(stride) for stride in tensor.unflattened.strides
595
+ )
583
596
 
584
- for source_dim in range(tensor.source.ndim):
585
- for target_dim in range(tensor.target.ndim):
586
- offsets[source_dim][target_dim] = copy.deepcopy(
587
- offsets[source_dim][target_dim]
597
+ def _add_unraveled_offsets(raw_offs, source_dim, target_dim, unflattened_dim):
598
+ if not isinstance(unflattened_dim, tuple):
599
+ offsets[source_dim][target_dim] += copy.deepcopy(
600
+ raw_offs
601
+ ).find_and_replace(
602
+ unflattened_strides, Symbol(1)
603
+ ) * unflattened_strides[unflattened_dim].find_and_replace(
604
+ source_strides, Symbol(1)
588
605
  )
589
- offsets[source_dim][target_dim].find_and_replace(
590
- Symbol(tensor.source.strides[source_dim]), Symbol(1)
606
+
607
+ return
608
+
609
+ unraveled_offs = CodeGenerator._unravel_index(
610
+ raw_offs,
611
+ tuple(tensor.unflattened.shape[dim] for dim in unflattened_dim),
612
+ )
613
+
614
+ for raw_offs, source_dim, unflattened_dim in zip(
615
+ unraveled_offs, source_dim, unflattened_dim
616
+ ):
617
+ _add_unraveled_offsets(
618
+ raw_offs, source_dim, target_dim, unflattened_dim
591
619
  )
592
620
 
621
+ for source_dim in tuple(raw_offsets):
622
+ for target_dim in tuple(raw_offsets[source_dim]):
623
+ for unflattened_dim in tuple(raw_offsets[source_dim][target_dim]):
624
+ _add_unraveled_offsets(
625
+ raw_offsets[source_dim][target_dim][unflattened_dim],
626
+ source_dim,
627
+ target_dim,
628
+ unflattened_dim,
629
+ )
630
+
593
631
  return offsets
594
632
 
595
633
  @staticmethod
ninetoothed/symbol.py CHANGED
@@ -140,7 +140,12 @@ class Symbol:
140
140
  return ast.unparse(self._node)
141
141
 
142
142
  def find_and_replace(self, target, replacement):
143
- _FindAndReplacer(target.node, replacement.node).visit(self._node)
143
+ if isinstance(target, tuple):
144
+ targets = tuple(item.node for item in target)
145
+ else:
146
+ targets = (target.node,)
147
+
148
+ return Symbol(_FindAndReplacer(targets, replacement.node).visit(self._node))
144
149
 
145
150
  def names(self):
146
151
  class NameCollector(ast.NodeVisitor):
@@ -175,12 +180,14 @@ class Symbol:
175
180
 
176
181
 
177
182
  class _FindAndReplacer(ast.NodeTransformer):
178
- def __init__(self, target, replacement):
179
- self._target_id = target.id
183
+ def __init__(self, targets, replacement):
184
+ self._targets_unparsed = tuple(
185
+ sorted({ast.unparse(target) for target in targets}, key=len, reverse=True)
186
+ )
180
187
  self._replacement = replacement
181
188
 
182
- def visit_Name(self, node):
183
- if node.id == self._target_id:
189
+ def visit(self, node):
190
+ if ast.unparse(node) in self._targets_unparsed:
184
191
  return self._replacement
185
192
 
186
- return self.generic_visit(node)
193
+ return super().visit(node)
ninetoothed/tensor.py CHANGED
@@ -3,7 +3,6 @@ import math
3
3
  import re
4
4
 
5
5
  import ninetoothed.naming as naming
6
- from ninetoothed.language import call
7
6
  from ninetoothed.symbol import Symbol
8
7
 
9
8
 
@@ -38,6 +37,8 @@ class Tensor:
38
37
  source_dims=None,
39
38
  target=None,
40
39
  target_dims=None,
40
+ unflattened=None,
41
+ unflattened_dims=None,
41
42
  ):
42
43
  self.dtype = dtype
43
44
 
@@ -82,6 +83,16 @@ class Tensor:
82
83
  else:
83
84
  self.target_dims = (dim for dim in range(self.target.ndim))
84
85
 
86
+ if unflattened is not None:
87
+ self.unflattened = unflattened
88
+ else:
89
+ self.unflattened = self
90
+
91
+ if unflattened_dims is not None:
92
+ self.unflattened_dims = unflattened_dims
93
+ else:
94
+ self.unflattened_dims = (dim for dim in range(self.unflattened.ndim))
95
+
85
96
  type(self).num_instances += 1
86
97
 
87
98
  def tile(self, tile_shape, strides=None, dilation=None):
@@ -113,8 +124,11 @@ class Tensor:
113
124
  if stride == -1:
114
125
  stride = tile_size
115
126
 
127
+ def cdiv(x, y):
128
+ return (x + y - 1) // y
129
+
116
130
  new_size = (
117
- call("cdiv", self_size - spacing * (tile_size - 1) - 1, stride) + 1
131
+ (cdiv(self_size - spacing * (tile_size - 1) - 1, stride) + 1)
118
132
  if stride != 0
119
133
  else -1
120
134
  )
@@ -135,10 +149,14 @@ class Tensor:
135
149
  strides=inner_strides,
136
150
  source=self.source,
137
151
  source_dims=self.source_dims,
152
+ unflattened=self.unflattened,
153
+ unflattened_dims=self.unflattened_dims,
138
154
  ),
139
155
  strides=outer_strides,
140
156
  source=self.source,
141
157
  source_dims=self.source_dims,
158
+ unflattened=self.unflattened,
159
+ unflattened_dims=self.unflattened_dims,
142
160
  )
143
161
 
144
162
  def expand(self, shape):
@@ -162,6 +180,8 @@ class Tensor:
162
180
  source=self.source,
163
181
  source_dims=self.source_dims,
164
182
  target_dims=self.target_dims,
183
+ unflattened=self.unflattened,
184
+ unflattened_dims=self.unflattened_dims,
165
185
  )
166
186
 
167
187
  def squeeze(self, dim):
@@ -190,6 +210,12 @@ class Tensor:
190
210
  for i, target_dim in enumerate(self.target_dims)
191
211
  if i not in dim
192
212
  ],
213
+ unflattened=self.unflattened,
214
+ unflattened_dims=[
215
+ unflattened_dim
216
+ for i, unflattened_dim in enumerate(self.unflattened_dims)
217
+ if i not in dim
218
+ ],
193
219
  )
194
220
 
195
221
  def permute(self, dims):
@@ -203,11 +229,13 @@ class Tensor:
203
229
  new_shape = [None for _ in range(self.ndim)]
204
230
  new_strides = [None for _ in range(self.ndim)]
205
231
  new_source_dims = [None for _ in range(self.ndim)]
232
+ new_unflattened_dims = [None for _ in range(self.ndim)]
206
233
 
207
234
  for original_dim, permuted_dim in enumerate(dims):
208
235
  new_shape[original_dim] = self.shape[permuted_dim]
209
236
  new_strides[original_dim] = self.strides[permuted_dim]
210
237
  new_source_dims[original_dim] = self.source_dims[permuted_dim]
238
+ new_unflattened_dims[original_dim] = self.unflattened_dims[permuted_dim]
211
239
 
212
240
  return type(self)(
213
241
  shape=new_shape,
@@ -216,6 +244,8 @@ class Tensor:
216
244
  source=self.source,
217
245
  source_dims=new_source_dims,
218
246
  target_dims=self.target_dims,
247
+ unflattened=self.unflattened,
248
+ unflattened_dims=new_unflattened_dims,
219
249
  )
220
250
 
221
251
  def flatten(self, start_dim=None, end_dim=None):
@@ -263,6 +293,16 @@ class Tensor:
263
293
  leading_target_dims + (flattening_target_dims[-1],) + trailing_target_dims
264
294
  )
265
295
 
296
+ leading_unflattened_dims = self.unflattened_dims[:start_dim]
297
+ flattening_unflattened_dims = self.unflattened_dims[start_dim:end_dim]
298
+ trailing_unflattened_dims = self.unflattened_dims[end_dim:]
299
+
300
+ new_unflattened_dims = (
301
+ leading_unflattened_dims
302
+ + (flattening_unflattened_dims,)
303
+ + trailing_unflattened_dims
304
+ )
305
+
266
306
  return type(self)(
267
307
  shape=new_shape,
268
308
  dtype=self.dtype,
@@ -270,6 +310,8 @@ class Tensor:
270
310
  source=self.source,
271
311
  source_dims=new_source_dims,
272
312
  target_dims=new_target_dims,
313
+ unflattened=self.unflattened,
314
+ unflattened_dims=new_unflattened_dims,
273
315
  )
274
316
 
275
317
  def ravel(self):
@@ -288,12 +330,14 @@ class Tensor:
288
330
  # TODO: Add error handling.
289
331
  new_shape = []
290
332
  new_strides = []
333
+ new_source_dims = []
291
334
 
292
335
  curr = self
293
336
 
294
337
  while isinstance(curr, type(self)):
295
338
  new_shape.extend(curr.shape)
296
339
  new_strides.extend(curr.strides)
340
+ new_source_dims.extend(curr.source_dims)
297
341
 
298
342
  curr = curr.dtype
299
343
 
@@ -302,6 +346,8 @@ class Tensor:
302
346
  strides=new_strides,
303
347
  other=self.source.other,
304
348
  name=self.source.name,
349
+ source=self.source,
350
+ source_dims=new_source_dims,
305
351
  )
306
352
 
307
353
  def names(self):
@@ -383,6 +429,14 @@ class Tensor:
383
429
  def target_dims(self, value):
384
430
  self._target_dims = tuple(value)
385
431
 
432
+ @property
433
+ def unflattened_dims(self):
434
+ return self._unflattened_dims
435
+
436
+ @unflattened_dims.setter
437
+ def unflattened_dims(self, value):
438
+ self._unflattened_dims = tuple(value)
439
+
386
440
  @staticmethod
387
441
  def pointer_pattern():
388
442
  return re.compile(rf"({_identifier_pattern_raw_string()})_(pointer)")
@@ -0,0 +1,128 @@
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ from mpl_toolkits.axes_grid1 import Divider, Size
4
+
5
+
6
def visualize(tensor, color=None, save_path=None):
    """Visualize a tensor as a structured grid representation.

    :param tensor: The tensor to be visualized.
    :param color: The color to be used for visualization. When omitted,
        ``"C<n>"`` is used, where ``n`` is the number of calls to this
        function that have completed so far.
    :param save_path: The path where the visualization should be saved.
        When ``None``, the figure is rendered and closed without being
        written anywhere.
    """
    outline_width = 0.1

    if color is None:
        color = f"C{visualize.count}"

    # Line widths are given in points (72 per inch) and the output figure
    # below is sized so that one data unit spans one inch. The override is
    # scoped to this call so that unrelated plots keep their own settings.
    with plt.rc_context({"lines.linewidth": 72 * outline_width}):
        # Measuring pass: the extent of the drawing is only known after
        # drawing it once. Use a private scratch figure so that a figure
        # the caller already has open is neither drawn on nor kept alive.
        scratch = plt.figure()

        try:
            _, max_pos_x, max_pos_y = _visualize_tensor(
                scratch.gca(), tensor, 0, 0, color
            )
        finally:
            plt.close(scratch)

        # The drawing helpers track the vertical position first and the
        # horizontal position second, hence the swapped names here.
        width = max_pos_y + 1
        height = max_pos_x + 1

        fig = plt.figure(figsize=(width + outline_width, height + outline_width))

        try:
            h = (Size.Fixed(0), Size.Fixed(width + outline_width))
            v = (Size.Fixed(0), Size.Fixed(height + outline_width))

            divider = Divider(fig, (0, 0, 1, 1), h, v, aspect=False)

            ax = fig.add_axes(
                divider.get_position(), axes_locator=divider.new_locator(nx=1, ny=1)
            )

            ax.set_aspect("equal")
            ax.invert_yaxis()

            ax.axis("off")

            # Leave half an outline of margin on every side so that the
            # outermost strokes are not clipped.
            half_outline_width = outline_width / 2
            ax.set_xlim((-half_outline_width, width + half_outline_width))
            ax.set_ylim((-half_outline_width, height + half_outline_width))

            _visualize_tensor(ax, tensor, 0, 0, color)

            if save_path is not None:
                fig.savefig(
                    save_path, transparent=True, bbox_inches="tight", pad_inches=0
                )
        finally:
            # Always release the figure, even if drawing or saving failed.
            plt.close(fig)

    visualize.count += 1


# Number of completed `visualize` calls; selects the default color.
visualize.count = 0
54
+
55
+
56
def _visualize_tensor(ax, tensor, x, y, color, level_spacing=4):
    """Draw ``tensor`` and, recursively, every level nested in its ``dtype``.

    The level itself is drawn at ``(x, y)``. If ``tensor.dtype`` is not
    ``None``, it is drawn as the next level, shifted along the second
    coordinate by ``level_spacing + 1`` past the end of this level, and the
    two levels are joined by two dashed black lines.

    :param ax: The axes to draw on.
    :param tensor: The tensor whose levels are drawn.
    :param x: First coordinate of the level's origin.
    :param y: Second coordinate of the level's origin.
    :param color: The fill color passed on to every level.
    :param level_spacing: The gap left between consecutive levels.
    :return: A tuple ``(verts, max_pos_x, max_pos_y)`` holding the corner
        positions of this level as returned by ``_visualize_level`` and the
        largest positions reached by this level or any nested one.
    """
    verts = _visualize_level(ax, tensor, x, y, color)
    level_max_x, level_max_y = verts[1][1]

    if tensor.dtype is None:
        return verts, level_max_x, level_max_y

    # The nested level starts at this level's far corner along the second
    # coordinate, pushed out by the spacing.
    next_x, next_y = verts[0][1]
    next_y += level_spacing + 1

    # Forward `level_spacing` so that every nested level uses the same gap
    # as the first one.
    next_verts, max_pos_x, max_pos_y = _visualize_tensor(
        ax, tensor.dtype, next_x, next_y, color, level_spacing
    )

    # Connector outline: a rectangle filling the gap, with one corner
    # stretched by the nested level's largest first coordinate.
    conn_verts = [
        list(vert)
        for vert in _verts_of_rect(1, level_spacing, next_x, next_y - level_spacing)
    ]
    conn_verts[2][0] += next_verts[1][0][0]

    # The first stored coordinate is plotted vertically, the second one
    # horizontally, so they are unpacked in swapped order.
    pos_y, pos_x = zip(*conn_verts)
    pos_x = pos_x + (pos_x[0],)
    pos_y = pos_y + (pos_y[0],)

    # Only two edges of the closed outline are drawn.
    ax.plot(pos_x[1:3], pos_y[1:3], "k--")
    ax.plot(pos_x[3:5], pos_y[3:5], "k--")

    max_pos_x = max(max_pos_x, level_max_x)
    max_pos_y = max(max_pos_y, level_max_y)

    return verts, max_pos_x, max_pos_y
84
+
85
+
86
def _visualize_level(ax, level, x, y, color):
    """Draw one unit square per element of ``level``, starting at ``(x, y)``.

    Dimensions are laid out alternately along the two coordinates: the last
    dimension advances the second coordinate, the one before it advances
    the first coordinate, and so on. A step along one of the last two
    dimensions moves by one cell; a step along any earlier dimension moves
    past the full extent of the dimension two places after it, plus one
    cell of gap.

    :param ax: The axes to draw on.
    :param level: The object to draw; its ``ndim`` and ``shape`` are read.
    :param x: First coordinate of the origin.
    :param y: Second coordinate of the origin.
    :param color: The fill color of the squares.
    :return: ``(((x, y), (x, max_y)), ((max_x, y), (max_x, max_y)))``, where
        ``max_x`` and ``max_y`` are the largest square positions reached.
    """
    ndim = level.ndim
    shape = level.shape

    step = [1] * ndim

    for dim in range(-3, -ndim - 1, -1):
        step[dim] = step[dim + 2] * shape[dim + 2] + 1

    grid = np.indices(shape)
    # One row per element, holding that element's index along each axis.
    cells = np.stack([grid[axis].flatten() for axis in range(ndim)], axis=-1)

    max_pos_x, max_pos_y = x, y

    for cell in cells:
        pos = [x, y]

        for dim, index in enumerate(cell):
            pos[(ndim - dim) % 2] += index * step[dim]

        max_pos_x = max(max_pos_x, pos[0])
        max_pos_y = max(max_pos_y, pos[1])

        # The square helper takes the horizontal position first.
        _visualize_unit_square(ax, pos[1], pos[0], color)

    return (((x, y), (x, max_pos_y)), ((max_pos_x, y), (max_pos_x, max_pos_y)))
114
+
115
+
116
def _visualize_unit_square(ax, x, y, color):
    """Draw a filled, outlined 1-by-1 square whose first corner is ``(x, y)``."""
    side = 1
    _visualize_rect(ax, side, side, x, y, color)
118
+
119
+
120
def _visualize_rect(ax, width, height, x, y, color):
    """Fill a rectangle with ``color`` and trace its outline in black.

    :param ax: The axes to draw on.
    :param width: Extent of the rectangle along the first plotted axis.
    :param height: Extent of the rectangle along the second plotted axis.
    :param x: First coordinate of the starting corner.
    :param y: Second coordinate of the starting corner.
    :param color: The fill color.
    """
    corners = _verts_of_rect(width, height, x, y)
    pos_x = tuple(corner[0] for corner in corners)
    pos_y = tuple(corner[1] for corner in corners)

    ax.fill(pos_x, pos_y, color)
    # Repeat the first corner so that the outline is closed.
    ax.plot(pos_x + pos_x[:1], pos_y + pos_y[:1], "k")
125
+
126
+
127
+ def _verts_of_rect(width, height, x, y):
128
+ return ((x, y), (x + width, y), (x + width, y + height), (x, y + height))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ninetoothed
3
- Version: 0.11.1
3
+ Version: 0.13.0
4
4
  Summary: A domain-specific language based on Triton but providing higher-level abstraction.
5
5
  Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
6
6
  Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
@@ -11,6 +11,12 @@ Classifier: Operating System :: OS Independent
11
11
  Classifier: Programming Language :: Python :: 3
12
12
  Requires-Python: >=3.10
13
13
  Requires-Dist: triton>=3.0.0
14
+ Provides-Extra: all
15
+ Requires-Dist: matplotlib>=3.9.0; extra == 'all'
16
+ Requires-Dist: numpy>=2.1.0; extra == 'all'
17
+ Provides-Extra: visualization
18
+ Requires-Dist: matplotlib>=3.9.0; extra == 'visualization'
19
+ Requires-Dist: numpy>=2.1.0; extra == 'visualization'
14
20
  Description-Content-Type: text/markdown
15
21
 
16
22
  # NineToothed
@@ -0,0 +1,12 @@
1
+ ninetoothed/__init__.py,sha256=dX34sk5GA3OgWf1Jc4gJMW3UwcGcJsuG3hs3rkiqq6g,161
2
+ ninetoothed/jit.py,sha256=fuMfb4gEvz78n-7--GS1Ud14g33SLT-vI1kqt9BPAIQ,25759
3
+ ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
4
+ ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
5
+ ninetoothed/symbol.py,sha256=UpGmx_jvaDtowADnp1DwYC3fvBXSiaMiYpU-ewkVo50,5261
6
+ ninetoothed/tensor.py,sha256=Q7WPigNyGOgP5lYYac39pF6zlFbCyXYrICPgGOuuyr4,13976
7
+ ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
8
+ ninetoothed/visualization.py,sha256=IZ7iTT4dl5_JFbO-WfSWPFWpgkyPr4nylwhSZVy8gss,3601
9
+ ninetoothed-0.13.0.dist-info/METADATA,sha256=Vp_huL1YVgNAa9q6C0J5hZnCR7SRyMFRALGD21ikrjc,7311
10
+ ninetoothed-0.13.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
11
+ ninetoothed-0.13.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
12
+ ninetoothed-0.13.0.dist-info/RECORD,,
@@ -1,11 +0,0 @@
1
- ninetoothed/__init__.py,sha256=dX34sk5GA3OgWf1Jc4gJMW3UwcGcJsuG3hs3rkiqq6g,161
2
- ninetoothed/jit.py,sha256=U3Nen5vyx69ulW7_hnRuATW86Ag9NgVgd3U02NVB20c,24430
3
- ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
4
- ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
5
- ninetoothed/symbol.py,sha256=mN96tp-2eUxbiNfxuxtKWNSxOSdYqlcmpY2MYQ-FiEg,4993
6
- ninetoothed/tensor.py,sha256=Pgl3t08qNDJrjbNMLltEIwMu19vnKwMUncOq4aedTjY,11983
7
- ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
8
- ninetoothed-0.11.1.dist-info/METADATA,sha256=DsvP5HloDDgEiHUtPEu8blUVYCBr3_VXYILUikfOE-Y,7055
9
- ninetoothed-0.11.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
10
- ninetoothed-0.11.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
11
- ninetoothed-0.11.1.dist-info/RECORD,,