ninetoothed 0.8.0__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ninetoothed/jit.py +7 -16
- ninetoothed/naming.py +4 -0
- ninetoothed/tensor.py +17 -21
- ninetoothed/torchifier.py +4 -0
- {ninetoothed-0.8.0.dist-info → ninetoothed-0.10.0.dist-info}/METADATA +1 -1
- ninetoothed-0.10.0.dist-info/RECORD +11 -0
- ninetoothed-0.8.0.dist-info/RECORD +0 -11
- {ninetoothed-0.8.0.dist-info → ninetoothed-0.10.0.dist-info}/WHEEL +0 -0
- {ninetoothed-0.8.0.dist-info → ninetoothed-0.10.0.dist-info}/licenses/LICENSE +0 -0
ninetoothed/jit.py
CHANGED
@@ -39,23 +39,12 @@ def jit(_func=None, *, _prettify=False):
|
|
39
39
|
|
40
40
|
|
41
41
|
class JIT:
|
42
|
-
handles = collections.defaultdict(dict)
|
43
|
-
|
44
42
|
def __init__(self, func, _prettify=False):
|
45
43
|
self.func = func
|
46
44
|
|
47
45
|
self._prettify = _prettify
|
48
46
|
|
49
47
|
def __call__(self):
|
50
|
-
source_file = inspect.getsourcefile(self.func)
|
51
|
-
source_line = inspect.getsourcelines(self.func)[1]
|
52
|
-
|
53
|
-
if (
|
54
|
-
source_file in type(self).handles
|
55
|
-
and source_line in type(self).handles[source_file]
|
56
|
-
):
|
57
|
-
return type(self).handles[source_file][source_line]
|
58
|
-
|
59
48
|
tree = self._get_tree()
|
60
49
|
|
61
50
|
CodeGenerator(inspect.get_annotations(self.func)).visit(tree)
|
@@ -93,8 +82,6 @@ class JIT:
|
|
93
82
|
source,
|
94
83
|
)
|
95
84
|
|
96
|
-
type(self).handles[source_file][source_line] = handle
|
97
|
-
|
98
85
|
return handle
|
99
86
|
|
100
87
|
def _get_tree(self):
|
@@ -350,6 +337,7 @@ class CodeGenerator(ast.NodeTransformer):
|
|
350
337
|
+ [
|
351
338
|
ast.arg(arg=param)
|
352
339
|
for param in non_next_power_of_2_constexpr_params_without_prefixes
|
340
|
+
if not Tensor.size_pattern().fullmatch(param)
|
353
341
|
],
|
354
342
|
kwonlyargs=[],
|
355
343
|
defaults=[],
|
@@ -489,9 +477,9 @@ class CodeGenerator(ast.NodeTransformer):
|
|
489
477
|
return pointers, mask
|
490
478
|
|
491
479
|
def _complete_indices(self, tensor, indices):
|
492
|
-
indices = list(self._generate_pid_indices(tensor) + indices)
|
480
|
+
indices = list(self._generate_pid_indices(tensor) + tuple(indices))
|
493
481
|
|
494
|
-
for size in tensor.
|
482
|
+
for size in tensor.innermost().shape:
|
495
483
|
if Symbol.is_name(size):
|
496
484
|
name = size.node.id
|
497
485
|
if not naming.is_meta(name):
|
@@ -526,7 +514,10 @@ class CodeGenerator(ast.NodeTransformer):
|
|
526
514
|
|
527
515
|
@staticmethod
|
528
516
|
def _generate_slices(tensor, dim):
|
529
|
-
return tuple(
|
517
|
+
return tuple(
|
518
|
+
slice(None) if target_dim == dim else None
|
519
|
+
for target_dim in tensor.innermost().target_dims
|
520
|
+
)
|
530
521
|
|
531
522
|
@staticmethod
|
532
523
|
def _generate_offsets(tensor, indices):
|
ninetoothed/naming.py
CHANGED
ninetoothed/tensor.py
CHANGED
@@ -2,6 +2,7 @@ import itertools
|
|
2
2
|
import math
|
3
3
|
import re
|
4
4
|
|
5
|
+
import ninetoothed.naming as naming
|
5
6
|
from ninetoothed.language import call
|
6
7
|
from ninetoothed.symbol import Symbol
|
7
8
|
|
@@ -16,6 +17,7 @@ class Tensor:
|
|
16
17
|
dtype=None,
|
17
18
|
strides=None,
|
18
19
|
other=None,
|
20
|
+
constexpr_shape=None,
|
19
21
|
name=None,
|
20
22
|
source=None,
|
21
23
|
source_dims=None,
|
@@ -27,10 +29,13 @@ class Tensor:
|
|
27
29
|
if name is not None:
|
28
30
|
self.name = name
|
29
31
|
else:
|
30
|
-
self.name = f"
|
32
|
+
self.name = naming.auto_generate(f"tensor_{type(self).num_instances}")
|
31
33
|
|
32
34
|
if ndim is not None:
|
33
|
-
self.shape = (
|
35
|
+
self.shape = (
|
36
|
+
Symbol(self.size_string(i), constexpr=constexpr_shape)
|
37
|
+
for i in range(ndim)
|
38
|
+
)
|
34
39
|
self.strides = (Symbol(self.stride_string(i)) for i in range(ndim))
|
35
40
|
else:
|
36
41
|
self.shape = shape
|
@@ -107,14 +112,10 @@ class Tensor:
|
|
107
112
|
strides=inner_strides,
|
108
113
|
source=self.source,
|
109
114
|
source_dims=self.source_dims,
|
110
|
-
target=self.target,
|
111
|
-
target_dims=self.target_dims,
|
112
115
|
),
|
113
116
|
strides=outer_strides,
|
114
117
|
source=self.source,
|
115
118
|
source_dims=self.source_dims,
|
116
|
-
target=self.target,
|
117
|
-
target_dims=self.target_dims,
|
118
119
|
)
|
119
120
|
|
120
121
|
def expand(self, shape):
|
@@ -131,23 +132,22 @@ class Tensor:
|
|
131
132
|
],
|
132
133
|
source=self.source,
|
133
134
|
source_dims=self.source_dims,
|
134
|
-
target=self.target,
|
135
|
-
target_dims=self.target_dims,
|
136
135
|
)
|
137
136
|
|
138
137
|
def squeeze(self, dim):
|
138
|
+
if not isinstance(dim, tuple):
|
139
|
+
dim = (dim,)
|
140
|
+
|
139
141
|
# TODO: Add error handling.
|
140
142
|
return type(self)(
|
141
|
-
shape=[size for i, size in enumerate(self.shape) if
|
143
|
+
shape=[size for i, size in enumerate(self.shape) if i not in dim],
|
142
144
|
dtype=self.dtype,
|
143
|
-
strides=[stride for i, stride in enumerate(self.strides) if
|
145
|
+
strides=[stride for i, stride in enumerate(self.strides) if i not in dim],
|
144
146
|
source=self.source,
|
145
147
|
source_dims=[
|
146
|
-
source_dim
|
147
|
-
|
148
|
-
|
149
|
-
target_dims=[
|
150
|
-
target_dim for i, target_dim in enumerate(self.target_dims) if dim != i
|
148
|
+
source_dim
|
149
|
+
for i, source_dim in enumerate(self.source_dims)
|
150
|
+
if i not in dim
|
151
151
|
],
|
152
152
|
)
|
153
153
|
|
@@ -168,8 +168,6 @@ class Tensor:
|
|
168
168
|
strides=new_strides,
|
169
169
|
source=self.source,
|
170
170
|
source_dims=new_source_dims,
|
171
|
-
target=self.target,
|
172
|
-
target_dims=self.target_dims,
|
173
171
|
)
|
174
172
|
|
175
173
|
def flatten(self, start_dim=None, end_dim=None):
|
@@ -205,8 +203,6 @@ class Tensor:
|
|
205
203
|
strides=new_strides,
|
206
204
|
source=self.source,
|
207
205
|
source_dims=new_source_dims,
|
208
|
-
target=self.target,
|
209
|
-
target_dims=self.target_dims,
|
210
206
|
)
|
211
207
|
|
212
208
|
def ravel(self):
|
@@ -245,11 +241,11 @@ class Tensor:
|
|
245
241
|
| (self.source.names() if self.source is not self else set())
|
246
242
|
)
|
247
243
|
|
248
|
-
def
|
244
|
+
def innermost(self):
|
249
245
|
if not isinstance(self.dtype, type(self)):
|
250
246
|
return self
|
251
247
|
|
252
|
-
return self.dtype.
|
248
|
+
return self.dtype.innermost()
|
253
249
|
|
254
250
|
def pointer_string(self):
|
255
251
|
return f"{self.name}_pointer"
|
ninetoothed/torchifier.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
import ast
|
2
2
|
|
3
|
+
import ninetoothed.naming as naming
|
3
4
|
from ninetoothed.tensor import Tensor
|
4
5
|
|
5
6
|
|
@@ -9,6 +10,9 @@ class Torchifier(ast.NodeTransformer):
|
|
9
10
|
|
10
11
|
source = node.id
|
11
12
|
|
13
|
+
if naming.is_constexpr(source):
|
14
|
+
return node
|
15
|
+
|
12
16
|
def repl(match):
|
13
17
|
return f"{match.group(1)}"
|
14
18
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ninetoothed
|
3
|
-
Version: 0.8.0
|
3
|
+
Version: 0.10.0
|
4
4
|
Summary: A domain-specific language based on Triton but providing higher-level abstraction.
|
5
5
|
Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
|
6
6
|
Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
|
@@ -0,0 +1,11 @@
|
|
1
|
+
ninetoothed/__init__.py,sha256=dX34sk5GA3OgWf1Jc4gJMW3UwcGcJsuG3hs3rkiqq6g,161
|
2
|
+
ninetoothed/jit.py,sha256=0LeDBpSYFgPx4hatP_ZsvElsj0d9d552OKRc__L1Jvc,23460
|
3
|
+
ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
|
4
|
+
ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
|
5
|
+
ninetoothed/symbol.py,sha256=rZ5nXtn-U1Nw0BBRJ-kfrwmX_zCbAi76un-Z2QFaoZc,4773
|
6
|
+
ninetoothed/tensor.py,sha256=LGS9wYmPKckZjIEXsMHdclTmVhzkfrc3avOPiCQY1tU,9153
|
7
|
+
ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
|
8
|
+
ninetoothed-0.10.0.dist-info/METADATA,sha256=nQWkQ--AceNN3DoD-kP9_aykZYVw8LOfqf2iO63v1Ek,7055
|
9
|
+
ninetoothed-0.10.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
10
|
+
ninetoothed-0.10.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
11
|
+
ninetoothed-0.10.0.dist-info/RECORD,,
|
@@ -1,11 +0,0 @@
|
|
1
|
-
ninetoothed/__init__.py,sha256=dX34sk5GA3OgWf1Jc4gJMW3UwcGcJsuG3hs3rkiqq6g,161
|
2
|
-
ninetoothed/jit.py,sha256=z70hQEsogfQu0cLxq5m3cOsWsVANcMRJaVv5di9vk1c,23741
|
3
|
-
ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
|
4
|
-
ninetoothed/naming.py,sha256=3FBnC-S3dAZRcBcob9SrcVpVEYE5IXRacwkCiA3vIGU,891
|
5
|
-
ninetoothed/symbol.py,sha256=rZ5nXtn-U1Nw0BBRJ-kfrwmX_zCbAi76un-Z2QFaoZc,4773
|
6
|
-
ninetoothed/tensor.py,sha256=_jM0tVgqIwZd3MJJsGVTaLCsSxpPO8JfF4qkMShhQvQ,9429
|
7
|
-
ninetoothed/torchifier.py,sha256=8M2PDwyFIfVypX6Z-Vt_bGbsCPqxqKnftL0rXeh9bOM,911
|
8
|
-
ninetoothed-0.8.0.dist-info/METADATA,sha256=gPWYhTBH5EdeOyGnArZIEw82aFmoQchD6pxtLi6LGMA,7054
|
9
|
-
ninetoothed-0.8.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
10
|
-
ninetoothed-0.8.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
11
|
-
ninetoothed-0.8.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|