ninetoothed 0.8.0__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ninetoothed/jit.py CHANGED
@@ -39,23 +39,12 @@ def jit(_func=None, *, _prettify=False):
39
39
 
40
40
 
41
41
  class JIT:
42
- handles = collections.defaultdict(dict)
43
-
44
42
  def __init__(self, func, _prettify=False):
45
43
  self.func = func
46
44
 
47
45
  self._prettify = _prettify
48
46
 
49
47
  def __call__(self):
50
- source_file = inspect.getsourcefile(self.func)
51
- source_line = inspect.getsourcelines(self.func)[1]
52
-
53
- if (
54
- source_file in type(self).handles
55
- and source_line in type(self).handles[source_file]
56
- ):
57
- return type(self).handles[source_file][source_line]
58
-
59
48
  tree = self._get_tree()
60
49
 
61
50
  CodeGenerator(inspect.get_annotations(self.func)).visit(tree)
@@ -93,8 +82,6 @@ class JIT:
93
82
  source,
94
83
  )
95
84
 
96
- type(self).handles[source_file][source_line] = handle
97
-
98
85
  return handle
99
86
 
100
87
  def _get_tree(self):
@@ -350,6 +337,7 @@ class CodeGenerator(ast.NodeTransformer):
350
337
  + [
351
338
  ast.arg(arg=param)
352
339
  for param in non_next_power_of_2_constexpr_params_without_prefixes
340
+ if not Tensor.size_pattern().fullmatch(param)
353
341
  ],
354
342
  kwonlyargs=[],
355
343
  defaults=[],
@@ -489,9 +477,9 @@ class CodeGenerator(ast.NodeTransformer):
489
477
  return pointers, mask
490
478
 
491
479
  def _complete_indices(self, tensor, indices):
492
- indices = list(self._generate_pid_indices(tensor) + indices)
480
+ indices = list(self._generate_pid_indices(tensor) + tuple(indices))
493
481
 
494
- for size in tensor.inmost().shape:
482
+ for size in tensor.innermost().shape:
495
483
  if Symbol.is_name(size):
496
484
  name = size.node.id
497
485
  if not naming.is_meta(name):
@@ -526,7 +514,10 @@ class CodeGenerator(ast.NodeTransformer):
526
514
 
527
515
  @staticmethod
528
516
  def _generate_slices(tensor, dim):
529
- return tuple(slice(None) if i == dim else None for i in range(tensor.ndim))
517
+ return tuple(
518
+ slice(None) if target_dim == dim else None
519
+ for target_dim in tensor.innermost().target_dims
520
+ )
530
521
 
531
522
  @staticmethod
532
523
  def _generate_offsets(tensor, indices):
ninetoothed/naming.py CHANGED
@@ -1,6 +1,10 @@
1
1
  import re
2
2
 
3
3
 
4
+ def auto_generate(name):
5
+ return f"ninetoothed_{name}"
6
+
7
+
4
8
  def make_constexpr(name):
5
9
  return _add_prefix(name, _CONSTEXPR)
6
10
 
ninetoothed/tensor.py CHANGED
@@ -2,6 +2,7 @@ import itertools
2
2
  import math
3
3
  import re
4
4
 
5
+ import ninetoothed.naming as naming
5
6
  from ninetoothed.language import call
6
7
  from ninetoothed.symbol import Symbol
7
8
 
@@ -16,6 +17,7 @@ class Tensor:
16
17
  dtype=None,
17
18
  strides=None,
18
19
  other=None,
20
+ constexpr_shape=None,
19
21
  name=None,
20
22
  source=None,
21
23
  source_dims=None,
@@ -27,10 +29,13 @@ class Tensor:
27
29
  if name is not None:
28
30
  self.name = name
29
31
  else:
30
- self.name = f"_ninetoothed_tensor_{type(self).num_instances}"
32
+ self.name = naming.auto_generate(f"tensor_{type(self).num_instances}")
31
33
 
32
34
  if ndim is not None:
33
- self.shape = (Symbol(self.size_string(i)) for i in range(ndim))
35
+ self.shape = (
36
+ Symbol(self.size_string(i), constexpr=constexpr_shape)
37
+ for i in range(ndim)
38
+ )
34
39
  self.strides = (Symbol(self.stride_string(i)) for i in range(ndim))
35
40
  else:
36
41
  self.shape = shape
@@ -107,14 +112,10 @@ class Tensor:
107
112
  strides=inner_strides,
108
113
  source=self.source,
109
114
  source_dims=self.source_dims,
110
- target=self.target,
111
- target_dims=self.target_dims,
112
115
  ),
113
116
  strides=outer_strides,
114
117
  source=self.source,
115
118
  source_dims=self.source_dims,
116
- target=self.target,
117
- target_dims=self.target_dims,
118
119
  )
119
120
 
120
121
  def expand(self, shape):
@@ -131,23 +132,22 @@ class Tensor:
131
132
  ],
132
133
  source=self.source,
133
134
  source_dims=self.source_dims,
134
- target=self.target,
135
- target_dims=self.target_dims,
136
135
  )
137
136
 
138
137
  def squeeze(self, dim):
138
+ if not isinstance(dim, tuple):
139
+ dim = (dim,)
140
+
139
141
  # TODO: Add error handling.
140
142
  return type(self)(
141
- shape=[size for i, size in enumerate(self.shape) if dim != i],
143
+ shape=[size for i, size in enumerate(self.shape) if i not in dim],
142
144
  dtype=self.dtype,
143
- strides=[stride for i, stride in enumerate(self.strides) if dim != i],
145
+ strides=[stride for i, stride in enumerate(self.strides) if i not in dim],
144
146
  source=self.source,
145
147
  source_dims=[
146
- source_dim for i, source_dim in enumerate(self.source_dims) if dim != i
147
- ],
148
- target=self.target,
149
- target_dims=[
150
- target_dim for i, target_dim in enumerate(self.target_dims) if dim != i
148
+ source_dim
149
+ for i, source_dim in enumerate(self.source_dims)
150
+ if i not in dim
151
151
  ],
152
152
  )
153
153
 
@@ -168,8 +168,6 @@ class Tensor:
168
168
  strides=new_strides,
169
169
  source=self.source,
170
170
  source_dims=new_source_dims,
171
- target=self.target,
172
- target_dims=self.target_dims,
173
171
  )
174
172
 
175
173
  def flatten(self, start_dim=None, end_dim=None):
@@ -205,8 +203,6 @@ class Tensor:
205
203
  strides=new_strides,
206
204
  source=self.source,
207
205
  source_dims=new_source_dims,
208
- target=self.target,
209
- target_dims=self.target_dims,
210
206
  )
211
207
 
212
208
  def ravel(self):
@@ -245,11 +241,11 @@ class Tensor:
245
241
  | (self.source.names() if self.source is not self else set())
246
242
  )
247
243
 
248
- def inmost(self):
244
+ def innermost(self):
249
245
  if not isinstance(self.dtype, type(self)):
250
246
  return self
251
247
 
252
- return self.dtype.inmost()
248
+ return self.dtype.innermost()
253
249
 
254
250
  def pointer_string(self):
255
251
  return f"{self.name}_pointer"
ninetoothed/torchifier.py CHANGED
@@ -1,5 +1,6 @@
1
1
  import ast
2
2
 
3
+ import ninetoothed.naming as naming
3
4
  from ninetoothed.tensor import Tensor
4
5
 
5
6
 
@@ -9,6 +10,9 @@ class Torchifier(ast.NodeTransformer):
9
10
 
10
11
  source = node.id
11
12
 
13
+ if naming.is_constexpr(source):
14
+ return node
15
+
12
16
  def repl(match):
13
17
  return f"{match.group(1)}"
14
18
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ninetoothed
3
- Version: 0.8.0
3
+ Version: 0.10.0
4
4
  Summary: A domain-specific language based on Triton but providing higher-level abstraction.
5
5
  Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
6
6
  Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
@@ -0,0 +1,11 @@
1
+ ninetoothed/__init__.py,sha256=dX34sk5GA3OgWf1Jc4gJMW3UwcGcJsuG3hs3rkiqq6g,161
2
+ ninetoothed/jit.py,sha256=0LeDBpSYFgPx4hatP_ZsvElsj0d9d552OKRc__L1Jvc,23460
3
+ ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
4
+ ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
5
+ ninetoothed/symbol.py,sha256=rZ5nXtn-U1Nw0BBRJ-kfrwmX_zCbAi76un-Z2QFaoZc,4773
6
+ ninetoothed/tensor.py,sha256=LGS9wYmPKckZjIEXsMHdclTmVhzkfrc3avOPiCQY1tU,9153
7
+ ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
8
+ ninetoothed-0.10.0.dist-info/METADATA,sha256=nQWkQ--AceNN3DoD-kP9_aykZYVw8LOfqf2iO63v1Ek,7055
9
+ ninetoothed-0.10.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
10
+ ninetoothed-0.10.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
11
+ ninetoothed-0.10.0.dist-info/RECORD,,
@@ -1,11 +0,0 @@
1
- ninetoothed/__init__.py,sha256=dX34sk5GA3OgWf1Jc4gJMW3UwcGcJsuG3hs3rkiqq6g,161
2
- ninetoothed/jit.py,sha256=z70hQEsogfQu0cLxq5m3cOsWsVANcMRJaVv5di9vk1c,23741
3
- ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
4
- ninetoothed/naming.py,sha256=3FBnC-S3dAZRcBcob9SrcVpVEYE5IXRacwkCiA3vIGU,891
5
- ninetoothed/symbol.py,sha256=rZ5nXtn-U1Nw0BBRJ-kfrwmX_zCbAi76un-Z2QFaoZc,4773
6
- ninetoothed/tensor.py,sha256=_jM0tVgqIwZd3MJJsGVTaLCsSxpPO8JfF4qkMShhQvQ,9429
7
- ninetoothed/torchifier.py,sha256=8M2PDwyFIfVypX6Z-Vt_bGbsCPqxqKnftL0rXeh9bOM,911
8
- ninetoothed-0.8.0.dist-info/METADATA,sha256=gPWYhTBH5EdeOyGnArZIEw82aFmoQchD6pxtLi6LGMA,7054
9
- ninetoothed-0.8.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
10
- ninetoothed-0.8.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
11
- ninetoothed-0.8.0.dist-info/RECORD,,