ninetoothed 0.12.0__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ninetoothed/jit.py +68 -30
- ninetoothed/symbol.py +13 -6
- ninetoothed/tensor.py +52 -0
- ninetoothed/visualization.py +6 -0
- {ninetoothed-0.12.0.dist-info → ninetoothed-0.13.0.dist-info}/METADATA +4 -1
- ninetoothed-0.13.0.dist-info/RECORD +12 -0
- ninetoothed-0.12.0.dist-info/RECORD +0 -12
- {ninetoothed-0.12.0.dist-info → ninetoothed-0.13.0.dist-info}/WHEEL +0 -0
- {ninetoothed-0.12.0.dist-info → ninetoothed-0.13.0.dist-info}/licenses/LICENSE +0 -0
ninetoothed/jit.py
CHANGED
@@ -483,13 +483,15 @@ class CodeGenerator(ast.NodeTransformer):
|
|
483
483
|
mask = functools.reduce(
|
484
484
|
lambda x, y: x & y,
|
485
485
|
(
|
486
|
-
|
487
|
-
|
488
|
-
|
486
|
+
sum(
|
487
|
+
offsets[source_dim][target_dim][
|
488
|
+
type(self)._generate_slices(tensor, target_dim)
|
489
|
+
]
|
490
|
+
for target_dim in range(tensor.target.ndim)
|
491
|
+
if offsets[source_dim][target_dim] != 0
|
492
|
+
)
|
489
493
|
< tensor.source.shape[source_dim]
|
490
494
|
for source_dim in range(tensor.source.ndim)
|
491
|
-
for target_dim in range(tensor.target.ndim)
|
492
|
-
if offsets[source_dim][target_dim] != 0
|
493
495
|
),
|
494
496
|
) & functools.reduce(
|
495
497
|
lambda x, y: x & y,
|
@@ -505,13 +507,21 @@ class CodeGenerator(ast.NodeTransformer):
|
|
505
507
|
return pointers, mask
|
506
508
|
|
507
509
|
def _complete_indices(self, tensor, indices):
|
510
|
+
class _NextPowerOfTwoMaker(ast.NodeTransformer):
|
511
|
+
def visit_Name(self, node):
|
512
|
+
name = node.id
|
513
|
+
|
514
|
+
if not naming.is_meta(name):
|
515
|
+
next_power_of_2_name = naming.make_next_power_of_2(name)
|
516
|
+
|
517
|
+
return ast.Name(id=next_power_of_2_name, ctx=ast.Load())
|
518
|
+
|
519
|
+
return self.generic_visit(node)
|
520
|
+
|
508
521
|
indices = list(self._generate_pid_indices(tensor) + tuple(indices))
|
509
522
|
|
510
523
|
for size in tensor.innermost().shape:
|
511
|
-
|
512
|
-
name = size.node.id
|
513
|
-
if not naming.is_meta(name):
|
514
|
-
size = naming.make_next_power_of_2(name)
|
524
|
+
size = _NextPowerOfTwoMaker().visit(Symbol(copy.deepcopy(size)).node)
|
515
525
|
|
516
526
|
indices.append(call("arange", 0, size))
|
517
527
|
|
@@ -549,8 +559,10 @@ class CodeGenerator(ast.NodeTransformer):
|
|
549
559
|
|
550
560
|
@staticmethod
|
551
561
|
def _generate_offsets(tensor, indices):
|
552
|
-
|
553
|
-
lambda: collections.defaultdict(
|
562
|
+
raw_offsets = collections.defaultdict(
|
563
|
+
lambda: collections.defaultdict(
|
564
|
+
lambda: collections.defaultdict(lambda: Symbol(0))
|
565
|
+
)
|
554
566
|
)
|
555
567
|
|
556
568
|
curr = tensor
|
@@ -560,36 +572,62 @@ class CodeGenerator(ast.NodeTransformer):
|
|
560
572
|
stop = start + curr.ndim
|
561
573
|
curr_indices = indices[start:stop]
|
562
574
|
|
563
|
-
for index, stride, source_dim, target_dim in zip(
|
564
|
-
curr_indices,
|
575
|
+
for index, stride, source_dim, target_dim, unflattened_dim in zip(
|
576
|
+
curr_indices,
|
577
|
+
curr.strides,
|
578
|
+
curr.source_dims,
|
579
|
+
curr.target_dims,
|
580
|
+
curr.unflattened_dims,
|
565
581
|
):
|
566
|
-
|
582
|
+
raw_offsets[source_dim][target_dim][unflattened_dim] += index * stride
|
567
583
|
|
568
584
|
start = stop
|
569
585
|
curr = curr.dtype
|
570
586
|
|
571
|
-
|
572
|
-
|
573
|
-
|
574
|
-
continue
|
587
|
+
offsets = collections.defaultdict(
|
588
|
+
lambda: collections.defaultdict(lambda: Symbol(0))
|
589
|
+
)
|
575
590
|
|
576
|
-
|
577
|
-
offsets[source_dim][target_dim],
|
578
|
-
tuple(tensor.source.shape[dim] for dim in source_dim),
|
579
|
-
)
|
591
|
+
source_strides = tuple(Symbol(stride) for stride in tensor.source.strides)
|
580
592
|
|
581
|
-
|
582
|
-
|
593
|
+
unflattened_strides = tuple(
|
594
|
+
Symbol(stride) for stride in tensor.unflattened.strides
|
595
|
+
)
|
583
596
|
|
584
|
-
|
585
|
-
|
586
|
-
offsets[source_dim][target_dim]
|
587
|
-
|
597
|
+
def _add_unraveled_offsets(raw_offs, source_dim, target_dim, unflattened_dim):
|
598
|
+
if not isinstance(unflattened_dim, tuple):
|
599
|
+
offsets[source_dim][target_dim] += copy.deepcopy(
|
600
|
+
raw_offs
|
601
|
+
).find_and_replace(
|
602
|
+
unflattened_strides, Symbol(1)
|
603
|
+
) * unflattened_strides[unflattened_dim].find_and_replace(
|
604
|
+
source_strides, Symbol(1)
|
588
605
|
)
|
589
|
-
|
590
|
-
|
606
|
+
|
607
|
+
return
|
608
|
+
|
609
|
+
unraveled_offs = CodeGenerator._unravel_index(
|
610
|
+
raw_offs,
|
611
|
+
tuple(tensor.unflattened.shape[dim] for dim in unflattened_dim),
|
612
|
+
)
|
613
|
+
|
614
|
+
for raw_offs, source_dim, unflattened_dim in zip(
|
615
|
+
unraveled_offs, source_dim, unflattened_dim
|
616
|
+
):
|
617
|
+
_add_unraveled_offsets(
|
618
|
+
raw_offs, source_dim, target_dim, unflattened_dim
|
591
619
|
)
|
592
620
|
|
621
|
+
for source_dim in tuple(raw_offsets):
|
622
|
+
for target_dim in tuple(raw_offsets[source_dim]):
|
623
|
+
for unflattened_dim in tuple(raw_offsets[source_dim][target_dim]):
|
624
|
+
_add_unraveled_offsets(
|
625
|
+
raw_offsets[source_dim][target_dim][unflattened_dim],
|
626
|
+
source_dim,
|
627
|
+
target_dim,
|
628
|
+
unflattened_dim,
|
629
|
+
)
|
630
|
+
|
593
631
|
return offsets
|
594
632
|
|
595
633
|
@staticmethod
|
ninetoothed/symbol.py
CHANGED
@@ -140,7 +140,12 @@ class Symbol:
|
|
140
140
|
return ast.unparse(self._node)
|
141
141
|
|
142
142
|
def find_and_replace(self, target, replacement):
|
143
|
-
|
143
|
+
if isinstance(target, tuple):
|
144
|
+
targets = tuple(item.node for item in target)
|
145
|
+
else:
|
146
|
+
targets = (target.node,)
|
147
|
+
|
148
|
+
return Symbol(_FindAndReplacer(targets, replacement.node).visit(self._node))
|
144
149
|
|
145
150
|
def names(self):
|
146
151
|
class NameCollector(ast.NodeVisitor):
|
@@ -175,12 +180,14 @@ class Symbol:
|
|
175
180
|
|
176
181
|
|
177
182
|
class _FindAndReplacer(ast.NodeTransformer):
|
178
|
-
def __init__(self,
|
179
|
-
self.
|
183
|
+
def __init__(self, targets, replacement):
|
184
|
+
self._targets_unparsed = tuple(
|
185
|
+
sorted({ast.unparse(target) for target in targets}, key=len, reverse=True)
|
186
|
+
)
|
180
187
|
self._replacement = replacement
|
181
188
|
|
182
|
-
def
|
183
|
-
if node
|
189
|
+
def visit(self, node):
|
190
|
+
if ast.unparse(node) in self._targets_unparsed:
|
184
191
|
return self._replacement
|
185
192
|
|
186
|
-
return
|
193
|
+
return super().visit(node)
|
ninetoothed/tensor.py
CHANGED
@@ -37,6 +37,8 @@ class Tensor:
|
|
37
37
|
source_dims=None,
|
38
38
|
target=None,
|
39
39
|
target_dims=None,
|
40
|
+
unflattened=None,
|
41
|
+
unflattened_dims=None,
|
40
42
|
):
|
41
43
|
self.dtype = dtype
|
42
44
|
|
@@ -81,6 +83,16 @@ class Tensor:
|
|
81
83
|
else:
|
82
84
|
self.target_dims = (dim for dim in range(self.target.ndim))
|
83
85
|
|
86
|
+
if unflattened is not None:
|
87
|
+
self.unflattened = unflattened
|
88
|
+
else:
|
89
|
+
self.unflattened = self
|
90
|
+
|
91
|
+
if unflattened_dims is not None:
|
92
|
+
self.unflattened_dims = unflattened_dims
|
93
|
+
else:
|
94
|
+
self.unflattened_dims = (dim for dim in range(self.unflattened.ndim))
|
95
|
+
|
84
96
|
type(self).num_instances += 1
|
85
97
|
|
86
98
|
def tile(self, tile_shape, strides=None, dilation=None):
|
@@ -137,10 +149,14 @@ class Tensor:
|
|
137
149
|
strides=inner_strides,
|
138
150
|
source=self.source,
|
139
151
|
source_dims=self.source_dims,
|
152
|
+
unflattened=self.unflattened,
|
153
|
+
unflattened_dims=self.unflattened_dims,
|
140
154
|
),
|
141
155
|
strides=outer_strides,
|
142
156
|
source=self.source,
|
143
157
|
source_dims=self.source_dims,
|
158
|
+
unflattened=self.unflattened,
|
159
|
+
unflattened_dims=self.unflattened_dims,
|
144
160
|
)
|
145
161
|
|
146
162
|
def expand(self, shape):
|
@@ -164,6 +180,8 @@ class Tensor:
|
|
164
180
|
source=self.source,
|
165
181
|
source_dims=self.source_dims,
|
166
182
|
target_dims=self.target_dims,
|
183
|
+
unflattened=self.unflattened,
|
184
|
+
unflattened_dims=self.unflattened_dims,
|
167
185
|
)
|
168
186
|
|
169
187
|
def squeeze(self, dim):
|
@@ -192,6 +210,12 @@ class Tensor:
|
|
192
210
|
for i, target_dim in enumerate(self.target_dims)
|
193
211
|
if i not in dim
|
194
212
|
],
|
213
|
+
unflattened=self.unflattened,
|
214
|
+
unflattened_dims=[
|
215
|
+
unflattened_dim
|
216
|
+
for i, unflattened_dim in enumerate(self.unflattened_dims)
|
217
|
+
if i not in dim
|
218
|
+
],
|
195
219
|
)
|
196
220
|
|
197
221
|
def permute(self, dims):
|
@@ -205,11 +229,13 @@ class Tensor:
|
|
205
229
|
new_shape = [None for _ in range(self.ndim)]
|
206
230
|
new_strides = [None for _ in range(self.ndim)]
|
207
231
|
new_source_dims = [None for _ in range(self.ndim)]
|
232
|
+
new_unflattened_dims = [None for _ in range(self.ndim)]
|
208
233
|
|
209
234
|
for original_dim, permuted_dim in enumerate(dims):
|
210
235
|
new_shape[original_dim] = self.shape[permuted_dim]
|
211
236
|
new_strides[original_dim] = self.strides[permuted_dim]
|
212
237
|
new_source_dims[original_dim] = self.source_dims[permuted_dim]
|
238
|
+
new_unflattened_dims[original_dim] = self.unflattened_dims[permuted_dim]
|
213
239
|
|
214
240
|
return type(self)(
|
215
241
|
shape=new_shape,
|
@@ -218,6 +244,8 @@ class Tensor:
|
|
218
244
|
source=self.source,
|
219
245
|
source_dims=new_source_dims,
|
220
246
|
target_dims=self.target_dims,
|
247
|
+
unflattened=self.unflattened,
|
248
|
+
unflattened_dims=new_unflattened_dims,
|
221
249
|
)
|
222
250
|
|
223
251
|
def flatten(self, start_dim=None, end_dim=None):
|
@@ -265,6 +293,16 @@ class Tensor:
|
|
265
293
|
leading_target_dims + (flattening_target_dims[-1],) + trailing_target_dims
|
266
294
|
)
|
267
295
|
|
296
|
+
leading_unflattened_dims = self.unflattened_dims[:start_dim]
|
297
|
+
flattening_unflattened_dims = self.unflattened_dims[start_dim:end_dim]
|
298
|
+
trailing_unflattened_dims = self.unflattened_dims[end_dim:]
|
299
|
+
|
300
|
+
new_unflattened_dims = (
|
301
|
+
leading_unflattened_dims
|
302
|
+
+ (flattening_unflattened_dims,)
|
303
|
+
+ trailing_unflattened_dims
|
304
|
+
)
|
305
|
+
|
268
306
|
return type(self)(
|
269
307
|
shape=new_shape,
|
270
308
|
dtype=self.dtype,
|
@@ -272,6 +310,8 @@ class Tensor:
|
|
272
310
|
source=self.source,
|
273
311
|
source_dims=new_source_dims,
|
274
312
|
target_dims=new_target_dims,
|
313
|
+
unflattened=self.unflattened,
|
314
|
+
unflattened_dims=new_unflattened_dims,
|
275
315
|
)
|
276
316
|
|
277
317
|
def ravel(self):
|
@@ -290,12 +330,14 @@ class Tensor:
|
|
290
330
|
# TODO: Add error handling.
|
291
331
|
new_shape = []
|
292
332
|
new_strides = []
|
333
|
+
new_source_dims = []
|
293
334
|
|
294
335
|
curr = self
|
295
336
|
|
296
337
|
while isinstance(curr, type(self)):
|
297
338
|
new_shape.extend(curr.shape)
|
298
339
|
new_strides.extend(curr.strides)
|
340
|
+
new_source_dims.extend(curr.source_dims)
|
299
341
|
|
300
342
|
curr = curr.dtype
|
301
343
|
|
@@ -304,6 +346,8 @@ class Tensor:
|
|
304
346
|
strides=new_strides,
|
305
347
|
other=self.source.other,
|
306
348
|
name=self.source.name,
|
349
|
+
source=self.source,
|
350
|
+
source_dims=new_source_dims,
|
307
351
|
)
|
308
352
|
|
309
353
|
def names(self):
|
@@ -385,6 +429,14 @@ class Tensor:
|
|
385
429
|
def target_dims(self, value):
|
386
430
|
self._target_dims = tuple(value)
|
387
431
|
|
432
|
+
@property
|
433
|
+
def unflattened_dims(self):
|
434
|
+
return self._unflattened_dims
|
435
|
+
|
436
|
+
@unflattened_dims.setter
|
437
|
+
def unflattened_dims(self, value):
|
438
|
+
self._unflattened_dims = tuple(value)
|
439
|
+
|
388
440
|
@staticmethod
|
389
441
|
def pointer_pattern():
|
390
442
|
return re.compile(rf"({_identifier_pattern_raw_string()})_(pointer)")
|
ninetoothed/visualization.py
CHANGED
@@ -4,6 +4,12 @@ from mpl_toolkits.axes_grid1 import Divider, Size
|
|
4
4
|
|
5
5
|
|
6
6
|
def visualize(tensor, color=None, save_path=None):
|
7
|
+
"""Visualize a tensor as a structured grid representation.
|
8
|
+
|
9
|
+
:param tensor: The tensor to be visualized.
|
10
|
+
:param color: The color to be used for visualization.
|
11
|
+
:param save_path: The path where the visualization should be saved.
|
12
|
+
"""
|
7
13
|
outline_width = 0.1
|
8
14
|
plt.rcParams["lines.linewidth"] = 72 * outline_width
|
9
15
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ninetoothed
|
3
|
-
Version: 0.12.0
|
3
|
+
Version: 0.13.0
|
4
4
|
Summary: A domain-specific language based on Triton but providing higher-level abstraction.
|
5
5
|
Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
|
6
6
|
Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
|
@@ -11,6 +11,9 @@ Classifier: Operating System :: OS Independent
|
|
11
11
|
Classifier: Programming Language :: Python :: 3
|
12
12
|
Requires-Python: >=3.10
|
13
13
|
Requires-Dist: triton>=3.0.0
|
14
|
+
Provides-Extra: all
|
15
|
+
Requires-Dist: matplotlib>=3.9.0; extra == 'all'
|
16
|
+
Requires-Dist: numpy>=2.1.0; extra == 'all'
|
14
17
|
Provides-Extra: visualization
|
15
18
|
Requires-Dist: matplotlib>=3.9.0; extra == 'visualization'
|
16
19
|
Requires-Dist: numpy>=2.1.0; extra == 'visualization'
|
@@ -0,0 +1,12 @@
|
|
1
|
+
ninetoothed/__init__.py,sha256=dX34sk5GA3OgWf1Jc4gJMW3UwcGcJsuG3hs3rkiqq6g,161
|
2
|
+
ninetoothed/jit.py,sha256=fuMfb4gEvz78n-7--GS1Ud14g33SLT-vI1kqt9BPAIQ,25759
|
3
|
+
ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
|
4
|
+
ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
|
5
|
+
ninetoothed/symbol.py,sha256=UpGmx_jvaDtowADnp1DwYC3fvBXSiaMiYpU-ewkVo50,5261
|
6
|
+
ninetoothed/tensor.py,sha256=Q7WPigNyGOgP5lYYac39pF6zlFbCyXYrICPgGOuuyr4,13976
|
7
|
+
ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
|
8
|
+
ninetoothed/visualization.py,sha256=IZ7iTT4dl5_JFbO-WfSWPFWpgkyPr4nylwhSZVy8gss,3601
|
9
|
+
ninetoothed-0.13.0.dist-info/METADATA,sha256=Vp_huL1YVgNAa9q6C0J5hZnCR7SRyMFRALGD21ikrjc,7311
|
10
|
+
ninetoothed-0.13.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
11
|
+
ninetoothed-0.13.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
12
|
+
ninetoothed-0.13.0.dist-info/RECORD,,
|
@@ -1,12 +0,0 @@
|
|
1
|
-
ninetoothed/__init__.py,sha256=dX34sk5GA3OgWf1Jc4gJMW3UwcGcJsuG3hs3rkiqq6g,161
|
2
|
-
ninetoothed/jit.py,sha256=U3Nen5vyx69ulW7_hnRuATW86Ag9NgVgd3U02NVB20c,24430
|
3
|
-
ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
|
4
|
-
ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
|
5
|
-
ninetoothed/symbol.py,sha256=mN96tp-2eUxbiNfxuxtKWNSxOSdYqlcmpY2MYQ-FiEg,4993
|
6
|
-
ninetoothed/tensor.py,sha256=E63sq3jh7ZLiLwFYTtavztKEZx7kRX-UVa2ZXSP2X0s,12008
|
7
|
-
ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
|
8
|
-
ninetoothed/visualization.py,sha256=VPPh__Bral_Z9hKj9D4UOo8HvRFQidCWSe9cS-D5QfY,3351
|
9
|
-
ninetoothed-0.12.0.dist-info/METADATA,sha256=gb5zeAxwYQRm-dFNmdX3uDIn_U-abEZ-VU9bSBBxioA,7198
|
10
|
-
ninetoothed-0.12.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
11
|
-
ninetoothed-0.12.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
12
|
-
ninetoothed-0.12.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|