ninetoothed 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ninetoothed/jit.py CHANGED
@@ -39,23 +39,12 @@ def jit(_func=None, *, _prettify=False):
39
39
 
40
40
 
41
41
  class JIT:
42
- handles = collections.defaultdict(dict)
43
-
44
42
  def __init__(self, func, _prettify=False):
45
43
  self.func = func
46
44
 
47
45
  self._prettify = _prettify
48
46
 
49
47
  def __call__(self):
50
- source_file = inspect.getsourcefile(self.func)
51
- source_line = inspect.getsourcelines(self.func)[1]
52
-
53
- if (
54
- source_file in type(self).handles
55
- and source_line in type(self).handles[source_file]
56
- ):
57
- return type(self).handles[source_file][source_line]
58
-
59
48
  tree = self._get_tree()
60
49
 
61
50
  CodeGenerator(inspect.get_annotations(self.func)).visit(tree)
@@ -93,8 +82,6 @@ class JIT:
93
82
  source,
94
83
  )
95
84
 
96
- type(self).handles[source_file][source_line] = handle
97
-
98
85
  return handle
99
86
 
100
87
  def _get_tree(self):
@@ -350,6 +337,7 @@ class CodeGenerator(ast.NodeTransformer):
350
337
  + [
351
338
  ast.arg(arg=param)
352
339
  for param in non_next_power_of_2_constexpr_params_without_prefixes
340
+ if not Tensor.size_pattern().fullmatch(param)
353
341
  ],
354
342
  kwonlyargs=[],
355
343
  defaults=[],
@@ -489,7 +477,7 @@ class CodeGenerator(ast.NodeTransformer):
489
477
  return pointers, mask
490
478
 
491
479
  def _complete_indices(self, tensor, indices):
492
- indices = list(self._generate_pid_indices(tensor) + indices)
480
+ indices = list(self._generate_pid_indices(tensor) + tuple(indices))
493
481
 
494
482
  for size in tensor.inmost().shape:
495
483
  if Symbol.is_name(size):
ninetoothed/naming.py CHANGED
@@ -1,6 +1,10 @@
1
1
  import re
2
2
 
3
3
 
4
+ def auto_generate(name):
5
+ return f"ninetoothed_{name}"
6
+
7
+
4
8
  def make_constexpr(name):
5
9
  return _add_prefix(name, _CONSTEXPR)
6
10
 
ninetoothed/tensor.py CHANGED
@@ -2,6 +2,7 @@ import itertools
2
2
  import math
3
3
  import re
4
4
 
5
+ import ninetoothed.naming as naming
5
6
  from ninetoothed.language import call
6
7
  from ninetoothed.symbol import Symbol
7
8
 
@@ -16,6 +17,7 @@ class Tensor:
16
17
  dtype=None,
17
18
  strides=None,
18
19
  other=None,
20
+ constexpr_shape=None,
19
21
  name=None,
20
22
  source=None,
21
23
  source_dims=None,
@@ -27,10 +29,13 @@ class Tensor:
27
29
  if name is not None:
28
30
  self.name = name
29
31
  else:
30
- self.name = f"_ninetoothed_tensor_{type(self).num_instances}"
32
+ self.name = naming.auto_generate(f"tensor_{type(self).num_instances}")
31
33
 
32
34
  if ndim is not None:
33
- self.shape = (Symbol(self.size_string(i)) for i in range(ndim))
35
+ self.shape = (
36
+ Symbol(self.size_string(i), constexpr=constexpr_shape)
37
+ for i in range(ndim)
38
+ )
34
39
  self.strides = (Symbol(self.stride_string(i)) for i in range(ndim))
35
40
  else:
36
41
  self.shape = shape
ninetoothed/torchifier.py CHANGED
@@ -1,5 +1,6 @@
1
1
  import ast
2
2
 
3
+ import ninetoothed.naming as naming
3
4
  from ninetoothed.tensor import Tensor
4
5
 
5
6
 
@@ -9,6 +10,9 @@ class Torchifier(ast.NodeTransformer):
9
10
 
10
11
  source = node.id
11
12
 
13
+ if naming.is_constexpr(source):
14
+ return node
15
+
12
16
  def repl(match):
13
17
  return f"{match.group(1)}"
14
18
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ninetoothed
3
- Version: 0.8.0
3
+ Version: 0.9.0
4
4
  Summary: A domain-specific language based on Triton but providing higher-level abstraction.
5
5
  Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
6
6
  Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
@@ -0,0 +1,11 @@
1
+ ninetoothed/__init__.py,sha256=dX34sk5GA3OgWf1Jc4gJMW3UwcGcJsuG3hs3rkiqq6g,161
2
+ ninetoothed/jit.py,sha256=B3q32ksKTRr7I4jLcoDXjwEx7A_Awz9DXGEmIkrtoBc,23393
3
+ ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
4
+ ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
5
+ ninetoothed/symbol.py,sha256=rZ5nXtn-U1Nw0BBRJ-kfrwmX_zCbAi76un-Z2QFaoZc,4773
6
+ ninetoothed/tensor.py,sha256=OU6lVjzKU614mk3EN1AAgTDradbvJkyl22AXwdhxcfs,9577
7
+ ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
8
+ ninetoothed-0.9.0.dist-info/METADATA,sha256=LKtSbgc_mWKJ4L_8BYiRKWjoV-JKjr4hefhrkdPmrHs,7054
9
+ ninetoothed-0.9.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
10
+ ninetoothed-0.9.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
11
+ ninetoothed-0.9.0.dist-info/RECORD,,
@@ -1,11 +0,0 @@
1
- ninetoothed/__init__.py,sha256=dX34sk5GA3OgWf1Jc4gJMW3UwcGcJsuG3hs3rkiqq6g,161
2
- ninetoothed/jit.py,sha256=z70hQEsogfQu0cLxq5m3cOsWsVANcMRJaVv5di9vk1c,23741
3
- ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
4
- ninetoothed/naming.py,sha256=3FBnC-S3dAZRcBcob9SrcVpVEYE5IXRacwkCiA3vIGU,891
5
- ninetoothed/symbol.py,sha256=rZ5nXtn-U1Nw0BBRJ-kfrwmX_zCbAi76un-Z2QFaoZc,4773
6
- ninetoothed/tensor.py,sha256=_jM0tVgqIwZd3MJJsGVTaLCsSxpPO8JfF4qkMShhQvQ,9429
7
- ninetoothed/torchifier.py,sha256=8M2PDwyFIfVypX6Z-Vt_bGbsCPqxqKnftL0rXeh9bOM,911
8
- ninetoothed-0.8.0.dist-info/METADATA,sha256=gPWYhTBH5EdeOyGnArZIEw82aFmoQchD6pxtLi6LGMA,7054
9
- ninetoothed-0.8.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
10
- ninetoothed-0.8.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
11
- ninetoothed-0.8.0.dist-info/RECORD,,