ninetoothed 0.12.0__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ninetoothed/jit.py CHANGED
@@ -483,13 +483,15 @@ class CodeGenerator(ast.NodeTransformer):
483
483
  mask = functools.reduce(
484
484
  lambda x, y: x & y,
485
485
  (
486
- offsets[source_dim][target_dim][
487
- type(self)._generate_slices(tensor, target_dim)
488
- ]
486
+ sum(
487
+ offsets[source_dim][target_dim][
488
+ type(self)._generate_slices(tensor, target_dim)
489
+ ]
490
+ for target_dim in range(tensor.target.ndim)
491
+ if offsets[source_dim][target_dim] != 0
492
+ )
489
493
  < tensor.source.shape[source_dim]
490
494
  for source_dim in range(tensor.source.ndim)
491
- for target_dim in range(tensor.target.ndim)
492
- if offsets[source_dim][target_dim] != 0
493
495
  ),
494
496
  ) & functools.reduce(
495
497
  lambda x, y: x & y,
@@ -505,13 +507,21 @@ class CodeGenerator(ast.NodeTransformer):
505
507
  return pointers, mask
506
508
 
507
509
  def _complete_indices(self, tensor, indices):
510
+ class _NextPowerOfTwoMaker(ast.NodeTransformer):
511
+ def visit_Name(self, node):
512
+ name = node.id
513
+
514
+ if not naming.is_meta(name):
515
+ next_power_of_2_name = naming.make_next_power_of_2(name)
516
+
517
+ return ast.Name(id=next_power_of_2_name, ctx=ast.Load())
518
+
519
+ return self.generic_visit(node)
520
+
508
521
  indices = list(self._generate_pid_indices(tensor) + tuple(indices))
509
522
 
510
523
  for size in tensor.innermost().shape:
511
- if Symbol.is_name(size):
512
- name = size.node.id
513
- if not naming.is_meta(name):
514
- size = naming.make_next_power_of_2(name)
524
+ size = _NextPowerOfTwoMaker().visit(Symbol(copy.deepcopy(size)).node)
515
525
 
516
526
  indices.append(call("arange", 0, size))
517
527
 
@@ -549,8 +559,10 @@ class CodeGenerator(ast.NodeTransformer):
549
559
 
550
560
  @staticmethod
551
561
  def _generate_offsets(tensor, indices):
552
- offsets = collections.defaultdict(
553
- lambda: collections.defaultdict(lambda: Symbol(0))
562
+ raw_offsets = collections.defaultdict(
563
+ lambda: collections.defaultdict(
564
+ lambda: collections.defaultdict(lambda: Symbol(0))
565
+ )
554
566
  )
555
567
 
556
568
  curr = tensor
@@ -560,36 +572,62 @@ class CodeGenerator(ast.NodeTransformer):
560
572
  stop = start + curr.ndim
561
573
  curr_indices = indices[start:stop]
562
574
 
563
- for index, stride, source_dim, target_dim in zip(
564
- curr_indices, curr.strides, curr.source_dims, curr.target_dims
575
+ for index, stride, source_dim, target_dim, unflattened_dim in zip(
576
+ curr_indices,
577
+ curr.strides,
578
+ curr.source_dims,
579
+ curr.target_dims,
580
+ curr.unflattened_dims,
565
581
  ):
566
- offsets[source_dim][target_dim] += index * stride
582
+ raw_offsets[source_dim][target_dim][unflattened_dim] += index * stride
567
583
 
568
584
  start = stop
569
585
  curr = curr.dtype
570
586
 
571
- for source_dim in tuple(offsets):
572
- for target_dim in tuple(offsets[source_dim]):
573
- if not isinstance(source_dim, tuple):
574
- continue
587
+ offsets = collections.defaultdict(
588
+ lambda: collections.defaultdict(lambda: Symbol(0))
589
+ )
575
590
 
576
- unraveled = CodeGenerator._unravel_index(
577
- offsets[source_dim][target_dim],
578
- tuple(tensor.source.shape[dim] for dim in source_dim),
579
- )
591
+ source_strides = tuple(Symbol(stride) for stride in tensor.source.strides)
580
592
 
581
- for offs, dim in zip(unraveled, source_dim):
582
- offsets[dim][target_dim] = offs
593
+ unflattened_strides = tuple(
594
+ Symbol(stride) for stride in tensor.unflattened.strides
595
+ )
583
596
 
584
- for source_dim in range(tensor.source.ndim):
585
- for target_dim in range(tensor.target.ndim):
586
- offsets[source_dim][target_dim] = copy.deepcopy(
587
- offsets[source_dim][target_dim]
597
+ def _add_unraveled_offsets(raw_offs, source_dim, target_dim, unflattened_dim):
598
+ if not isinstance(unflattened_dim, tuple):
599
+ offsets[source_dim][target_dim] += copy.deepcopy(
600
+ raw_offs
601
+ ).find_and_replace(
602
+ unflattened_strides, Symbol(1)
603
+ ) * unflattened_strides[unflattened_dim].find_and_replace(
604
+ source_strides, Symbol(1)
588
605
  )
589
- offsets[source_dim][target_dim].find_and_replace(
590
- Symbol(tensor.source.strides[source_dim]), Symbol(1)
606
+
607
+ return
608
+
609
+ unraveled_offs = CodeGenerator._unravel_index(
610
+ raw_offs,
611
+ tuple(tensor.unflattened.shape[dim] for dim in unflattened_dim),
612
+ )
613
+
614
+ for raw_offs, source_dim, unflattened_dim in zip(
615
+ unraveled_offs, source_dim, unflattened_dim
616
+ ):
617
+ _add_unraveled_offsets(
618
+ raw_offs, source_dim, target_dim, unflattened_dim
591
619
  )
592
620
 
621
+ for source_dim in tuple(raw_offsets):
622
+ for target_dim in tuple(raw_offsets[source_dim]):
623
+ for unflattened_dim in tuple(raw_offsets[source_dim][target_dim]):
624
+ _add_unraveled_offsets(
625
+ raw_offsets[source_dim][target_dim][unflattened_dim],
626
+ source_dim,
627
+ target_dim,
628
+ unflattened_dim,
629
+ )
630
+
593
631
  return offsets
594
632
 
595
633
  @staticmethod
ninetoothed/symbol.py CHANGED
@@ -140,7 +140,12 @@ class Symbol:
140
140
  return ast.unparse(self._node)
141
141
 
142
142
  def find_and_replace(self, target, replacement):
143
- _FindAndReplacer(target.node, replacement.node).visit(self._node)
143
+ if isinstance(target, tuple):
144
+ targets = tuple(item.node for item in target)
145
+ else:
146
+ targets = (target.node,)
147
+
148
+ return Symbol(_FindAndReplacer(targets, replacement.node).visit(self._node))
144
149
 
145
150
  def names(self):
146
151
  class NameCollector(ast.NodeVisitor):
@@ -175,12 +180,14 @@ class Symbol:
175
180
 
176
181
 
177
182
  class _FindAndReplacer(ast.NodeTransformer):
178
- def __init__(self, target, replacement):
179
- self._target_id = target.id
183
+ def __init__(self, targets, replacement):
184
+ self._targets_unparsed = tuple(
185
+ sorted({ast.unparse(target) for target in targets}, key=len, reverse=True)
186
+ )
180
187
  self._replacement = replacement
181
188
 
182
- def visit_Name(self, node):
183
- if node.id == self._target_id:
189
+ def visit(self, node):
190
+ if ast.unparse(node) in self._targets_unparsed:
184
191
  return self._replacement
185
192
 
186
- return self.generic_visit(node)
193
+ return super().visit(node)
ninetoothed/tensor.py CHANGED
@@ -37,6 +37,8 @@ class Tensor:
37
37
  source_dims=None,
38
38
  target=None,
39
39
  target_dims=None,
40
+ unflattened=None,
41
+ unflattened_dims=None,
40
42
  ):
41
43
  self.dtype = dtype
42
44
 
@@ -81,6 +83,16 @@ class Tensor:
81
83
  else:
82
84
  self.target_dims = (dim for dim in range(self.target.ndim))
83
85
 
86
+ if unflattened is not None:
87
+ self.unflattened = unflattened
88
+ else:
89
+ self.unflattened = self
90
+
91
+ if unflattened_dims is not None:
92
+ self.unflattened_dims = unflattened_dims
93
+ else:
94
+ self.unflattened_dims = (dim for dim in range(self.unflattened.ndim))
95
+
84
96
  type(self).num_instances += 1
85
97
 
86
98
  def tile(self, tile_shape, strides=None, dilation=None):
@@ -137,10 +149,14 @@ class Tensor:
137
149
  strides=inner_strides,
138
150
  source=self.source,
139
151
  source_dims=self.source_dims,
152
+ unflattened=self.unflattened,
153
+ unflattened_dims=self.unflattened_dims,
140
154
  ),
141
155
  strides=outer_strides,
142
156
  source=self.source,
143
157
  source_dims=self.source_dims,
158
+ unflattened=self.unflattened,
159
+ unflattened_dims=self.unflattened_dims,
144
160
  )
145
161
 
146
162
  def expand(self, shape):
@@ -164,6 +180,8 @@ class Tensor:
164
180
  source=self.source,
165
181
  source_dims=self.source_dims,
166
182
  target_dims=self.target_dims,
183
+ unflattened=self.unflattened,
184
+ unflattened_dims=self.unflattened_dims,
167
185
  )
168
186
 
169
187
  def squeeze(self, dim):
@@ -192,6 +210,12 @@ class Tensor:
192
210
  for i, target_dim in enumerate(self.target_dims)
193
211
  if i not in dim
194
212
  ],
213
+ unflattened=self.unflattened,
214
+ unflattened_dims=[
215
+ unflattened_dim
216
+ for i, unflattened_dim in enumerate(self.unflattened_dims)
217
+ if i not in dim
218
+ ],
195
219
  )
196
220
 
197
221
  def permute(self, dims):
@@ -205,11 +229,13 @@ class Tensor:
205
229
  new_shape = [None for _ in range(self.ndim)]
206
230
  new_strides = [None for _ in range(self.ndim)]
207
231
  new_source_dims = [None for _ in range(self.ndim)]
232
+ new_unflattened_dims = [None for _ in range(self.ndim)]
208
233
 
209
234
  for original_dim, permuted_dim in enumerate(dims):
210
235
  new_shape[original_dim] = self.shape[permuted_dim]
211
236
  new_strides[original_dim] = self.strides[permuted_dim]
212
237
  new_source_dims[original_dim] = self.source_dims[permuted_dim]
238
+ new_unflattened_dims[original_dim] = self.unflattened_dims[permuted_dim]
213
239
 
214
240
  return type(self)(
215
241
  shape=new_shape,
@@ -218,6 +244,8 @@ class Tensor:
218
244
  source=self.source,
219
245
  source_dims=new_source_dims,
220
246
  target_dims=self.target_dims,
247
+ unflattened=self.unflattened,
248
+ unflattened_dims=new_unflattened_dims,
221
249
  )
222
250
 
223
251
  def flatten(self, start_dim=None, end_dim=None):
@@ -265,6 +293,16 @@ class Tensor:
265
293
  leading_target_dims + (flattening_target_dims[-1],) + trailing_target_dims
266
294
  )
267
295
 
296
+ leading_unflattened_dims = self.unflattened_dims[:start_dim]
297
+ flattening_unflattened_dims = self.unflattened_dims[start_dim:end_dim]
298
+ trailing_unflattened_dims = self.unflattened_dims[end_dim:]
299
+
300
+ new_unflattened_dims = (
301
+ leading_unflattened_dims
302
+ + (flattening_unflattened_dims,)
303
+ + trailing_unflattened_dims
304
+ )
305
+
268
306
  return type(self)(
269
307
  shape=new_shape,
270
308
  dtype=self.dtype,
@@ -272,6 +310,8 @@ class Tensor:
272
310
  source=self.source,
273
311
  source_dims=new_source_dims,
274
312
  target_dims=new_target_dims,
313
+ unflattened=self.unflattened,
314
+ unflattened_dims=new_unflattened_dims,
275
315
  )
276
316
 
277
317
  def ravel(self):
@@ -290,12 +330,14 @@ class Tensor:
290
330
  # TODO: Add error handling.
291
331
  new_shape = []
292
332
  new_strides = []
333
+ new_source_dims = []
293
334
 
294
335
  curr = self
295
336
 
296
337
  while isinstance(curr, type(self)):
297
338
  new_shape.extend(curr.shape)
298
339
  new_strides.extend(curr.strides)
340
+ new_source_dims.extend(curr.source_dims)
299
341
 
300
342
  curr = curr.dtype
301
343
 
@@ -304,6 +346,8 @@ class Tensor:
304
346
  strides=new_strides,
305
347
  other=self.source.other,
306
348
  name=self.source.name,
349
+ source=self.source,
350
+ source_dims=new_source_dims,
307
351
  )
308
352
 
309
353
  def names(self):
@@ -385,6 +429,14 @@ class Tensor:
385
429
  def target_dims(self, value):
386
430
  self._target_dims = tuple(value)
387
431
 
432
+ @property
433
+ def unflattened_dims(self):
434
+ return self._unflattened_dims
435
+
436
+ @unflattened_dims.setter
437
+ def unflattened_dims(self, value):
438
+ self._unflattened_dims = tuple(value)
439
+
388
440
  @staticmethod
389
441
  def pointer_pattern():
390
442
  return re.compile(rf"({_identifier_pattern_raw_string()})_(pointer)")
@@ -4,6 +4,12 @@ from mpl_toolkits.axes_grid1 import Divider, Size
4
4
 
5
5
 
6
6
  def visualize(tensor, color=None, save_path=None):
7
+ """Visualize a tensor as a structured grid representation.
8
+
9
+ :param tensor: The tensor to be visualized.
10
+ :param color: The color to be used for visualization.
11
+ :param save_path: The path where the visualization should be saved.
12
+ """
7
13
  outline_width = 0.1
8
14
  plt.rcParams["lines.linewidth"] = 72 * outline_width
9
15
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ninetoothed
3
- Version: 0.12.0
3
+ Version: 0.13.0
4
4
  Summary: A domain-specific language based on Triton but providing higher-level abstraction.
5
5
  Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
6
6
  Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
@@ -11,6 +11,9 @@ Classifier: Operating System :: OS Independent
11
11
  Classifier: Programming Language :: Python :: 3
12
12
  Requires-Python: >=3.10
13
13
  Requires-Dist: triton>=3.0.0
14
+ Provides-Extra: all
15
+ Requires-Dist: matplotlib>=3.9.0; extra == 'all'
16
+ Requires-Dist: numpy>=2.1.0; extra == 'all'
14
17
  Provides-Extra: visualization
15
18
  Requires-Dist: matplotlib>=3.9.0; extra == 'visualization'
16
19
  Requires-Dist: numpy>=2.1.0; extra == 'visualization'
@@ -0,0 +1,12 @@
1
+ ninetoothed/__init__.py,sha256=dX34sk5GA3OgWf1Jc4gJMW3UwcGcJsuG3hs3rkiqq6g,161
2
+ ninetoothed/jit.py,sha256=fuMfb4gEvz78n-7--GS1Ud14g33SLT-vI1kqt9BPAIQ,25759
3
+ ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
4
+ ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
5
+ ninetoothed/symbol.py,sha256=UpGmx_jvaDtowADnp1DwYC3fvBXSiaMiYpU-ewkVo50,5261
6
+ ninetoothed/tensor.py,sha256=Q7WPigNyGOgP5lYYac39pF6zlFbCyXYrICPgGOuuyr4,13976
7
+ ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
8
+ ninetoothed/visualization.py,sha256=IZ7iTT4dl5_JFbO-WfSWPFWpgkyPr4nylwhSZVy8gss,3601
9
+ ninetoothed-0.13.0.dist-info/METADATA,sha256=Vp_huL1YVgNAa9q6C0J5hZnCR7SRyMFRALGD21ikrjc,7311
10
+ ninetoothed-0.13.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
11
+ ninetoothed-0.13.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
12
+ ninetoothed-0.13.0.dist-info/RECORD,,
@@ -1,12 +0,0 @@
1
- ninetoothed/__init__.py,sha256=dX34sk5GA3OgWf1Jc4gJMW3UwcGcJsuG3hs3rkiqq6g,161
2
- ninetoothed/jit.py,sha256=U3Nen5vyx69ulW7_hnRuATW86Ag9NgVgd3U02NVB20c,24430
3
- ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
4
- ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
5
- ninetoothed/symbol.py,sha256=mN96tp-2eUxbiNfxuxtKWNSxOSdYqlcmpY2MYQ-FiEg,4993
6
- ninetoothed/tensor.py,sha256=E63sq3jh7ZLiLwFYTtavztKEZx7kRX-UVa2ZXSP2X0s,12008
7
- ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
8
- ninetoothed/visualization.py,sha256=VPPh__Bral_Z9hKj9D4UOo8HvRFQidCWSe9cS-D5QfY,3351
9
- ninetoothed-0.12.0.dist-info/METADATA,sha256=gb5zeAxwYQRm-dFNmdX3uDIn_U-abEZ-VU9bSBBxioA,7198
10
- ninetoothed-0.12.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
11
- ninetoothed-0.12.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
12
- ninetoothed-0.12.0.dist-info/RECORD,,