ninetoothed 0.1.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ninetoothed/jit.py CHANGED
@@ -5,7 +5,6 @@ import inspect
5
5
  import itertools
6
6
  import math
7
7
  import tempfile
8
- import textwrap
9
8
 
10
9
  from ninetoothed.language import attribute, call
11
10
  from ninetoothed.symbol import Symbol
@@ -33,8 +32,7 @@ class JIT:
33
32
  ):
34
33
  return type(self).handles[source_file][source_line]
35
34
 
36
- source = textwrap.dedent(inspect.getsource(self.func))
37
- tree = ast.parse(source)
35
+ tree = self._get_tree()
38
36
 
39
37
  CodeGenerator(inspect.get_annotations(self.func)).visit(tree)
40
38
  Tritonizer().visit(tree)
@@ -56,15 +54,7 @@ class JIT:
56
54
  namespace = {}
57
55
  exec(code, namespace)
58
56
 
59
- class Handle:
60
- def __init__(self, kernel, launch):
61
- self._kernel = kernel
62
- self._launch = launch
63
-
64
- def __call__(self, *args, **kwargs):
65
- return self._launch(*args, **kwargs)
66
-
67
- handle = Handle(
57
+ handle = _Handle(
68
58
  namespace[self.func.__name__],
69
59
  namespace[f"launch_{self.func.__name__}"],
70
60
  )
@@ -73,6 +63,15 @@ class JIT:
73
63
 
74
64
  return handle
75
65
 
66
+ def _get_tree(self):
67
+ module = ast.parse(inspect.getsource(inspect.getmodule(self.func)))
68
+
69
+ _AliasRestorer().visit(module)
70
+ finder = _FunctionDefFinder(self.func.__name__)
71
+ finder.visit(module)
72
+
73
+ return ast.Module(body=[finder.result], type_ignores=[])
74
+
76
75
 
77
76
  class CodeGenerator(ast.NodeTransformer):
78
77
  def __init__(self, context):
@@ -100,6 +99,18 @@ class CodeGenerator(ast.NodeTransformer):
100
99
 
101
100
  self.generic_visit(node)
102
101
 
102
+ for arg in self._args:
103
+ if not isinstance(arg, Tensor):
104
+ continue
105
+
106
+ node.body.insert(
107
+ 0,
108
+ ast.Assign(
109
+ targets=[Symbol(f"{arg.name}_ptrs").node],
110
+ value=arg.pointers().node,
111
+ ),
112
+ )
113
+
103
114
  return node
104
115
 
105
116
  def visit_arguments(self, node):
@@ -136,12 +147,12 @@ class CodeGenerator(ast.NodeTransformer):
136
147
  value = self._context[node.value.id]
137
148
 
138
149
  if isinstance(value, Tensor):
139
- if isinstance(node.slice, ast.Tuple):
140
- indices = value.indices() + tuple(node.slice.elts)
141
- else:
142
- indices = value.indices() + (node.slice,)
143
- offsets = value.offsets(indices)
144
- pointers = value.pointers(offsets)
150
+ pointers = type(self)._create_pointers(
151
+ value,
152
+ node.slice.elts
153
+ if isinstance(node.slice, ast.Tuple)
154
+ else (node.slice,),
155
+ )
145
156
 
146
157
  return call("load", pointers).node
147
158
 
@@ -166,7 +177,9 @@ class CodeGenerator(ast.NodeTransformer):
166
177
  self.generic_visit(node)
167
178
 
168
179
  if node.id in self._context and isinstance(node.ctx, ast.Load):
169
- return call("load", self._context[node.id].pointers().node).node
180
+ return call(
181
+ "load", type(self)._create_pointers(self._context[node.id], ()).node
182
+ ).node
170
183
 
171
184
  return node
172
185
 
@@ -180,7 +193,7 @@ class CodeGenerator(ast.NodeTransformer):
180
193
  return ast.Expr(
181
194
  call(
182
195
  "store",
183
- self._context[target.id].pointers().node,
196
+ type(self)._create_pointers(self._context[target.id], ()).node,
184
197
  node.value,
185
198
  ).node
186
199
  )
@@ -195,13 +208,12 @@ class CodeGenerator(ast.NodeTransformer):
195
208
  if isinstance(value, Tensor):
196
209
  self.generic_visit(node)
197
210
 
198
- indices = value.indices() + tuple(
211
+ pointers = type(self)._create_pointers(
212
+ value,
199
213
  target.slice.elts
200
214
  if isinstance(target.slice, ast.Tuple)
201
- else target.slice
215
+ else (target.slice,),
202
216
  )
203
- offsets = value.offsets(indices)
204
- pointers = value.pointers(offsets)
205
217
 
206
218
  return ast.Expr(
207
219
  call(
@@ -316,6 +328,14 @@ class CodeGenerator(ast.NodeTransformer):
316
328
 
317
329
  return ast.parse(f"lambda meta: ({num_elements},)", mode="eval").body
318
330
 
331
+ @staticmethod
332
+ def _create_pointers(tensor, indices):
333
+ return Symbol(f"{tensor.name}_ptrs") + tensor.offsets(
334
+ [0 for _ in range(tensor.ndim())]
335
+ + list(indices)
336
+ + [0 for _ in range(tensor.inmost().ndim())]
337
+ )
338
+
319
339
 
320
340
  class Tritonizer(ast.NodeTransformer):
321
341
  def visit_Module(self, node):
@@ -329,8 +349,8 @@ class Tritonizer(ast.NodeTransformer):
329
349
  def visit_Name(self, node):
330
350
  self.generic_visit(node)
331
351
 
332
- if node.id == "ninetoothed":
333
- node.id = "triton"
352
+ if node.id == "ninetoothed" or "ninetoothed." in node.id:
353
+ node.id = node.id.replace("ninetoothed", "triton")
334
354
 
335
355
  return node
336
356
 
@@ -348,3 +368,73 @@ class Tritonizer(ast.NodeTransformer):
348
368
  )
349
369
 
350
370
  return node
371
+
372
+
373
+ class _Handle:
374
+ def __init__(self, kernel, launch):
375
+ self._kernel = kernel
376
+ self._launch = launch
377
+
378
+ def __call__(self, *args, **kwargs):
379
+ return self._launch(*args, **kwargs)
380
+
381
+
382
+ class _AliasRestorer(ast.NodeTransformer):
383
+ def __init__(self):
384
+ super().__init__()
385
+
386
+ self._aliases = {}
387
+ self._redefined = set()
388
+
389
+ def visit_Import(self, node):
390
+ for alias in node.names:
391
+ if alias.asname:
392
+ self._aliases[alias.asname] = alias.name
393
+
394
+ return node
395
+
396
+ def visit_ImportFrom(self, node):
397
+ for alias in node.names:
398
+ full_name = f"{node.module}.{alias.name}"
399
+ if alias.asname:
400
+ self._aliases[alias.asname] = full_name
401
+
402
+ return node
403
+
404
+ def visit_Assign(self, node):
405
+ for target in node.targets:
406
+ if isinstance(target, ast.Name):
407
+ self._redefined.add(target.id)
408
+
409
+ return self.generic_visit(node)
410
+
411
+ def visit_FunctionDef(self, node):
412
+ original_redefined = self._redefined.copy()
413
+
414
+ self.generic_visit(node)
415
+
416
+ self._redefined = original_redefined
417
+
418
+ return node
419
+
420
+ def visit_Name(self, node):
421
+ if node.id in self._redefined:
422
+ return node
423
+
424
+ if node.id in self._aliases:
425
+ return ast.Name(id=self._aliases[node.id], ctx=node.ctx)
426
+
427
+ return node
428
+
429
+
430
+ class _FunctionDefFinder(ast.NodeVisitor):
431
+ def __init__(self, name):
432
+ self._name = name
433
+
434
+ self.result = None
435
+
436
+ def visit_FunctionDef(self, node):
437
+ if node.name == self._name:
438
+ self.result = node
439
+
440
+ self.generic_visit(node)
ninetoothed/symbol.py CHANGED
@@ -34,24 +34,47 @@ class Symbol:
34
34
  self._node.id = type(self)._create_constexpr(self._node.id)
35
35
 
36
36
  def __add__(self, other):
37
- return type(self)(
38
- ast.BinOp(left=self._node, op=ast.Add(), right=type(self)(other)._node)
39
- )
37
+ other = type(self)(other)
38
+
39
+ if isinstance(self._node, ast.Constant) and self._node.value == 0:
40
+ return other
41
+
42
+ if isinstance(other._node, ast.Constant) and other._node.value == 0:
43
+ return self
44
+
45
+ return type(self)(ast.BinOp(left=self._node, op=ast.Add(), right=other._node))
40
46
 
41
47
  def __radd__(self, other):
42
48
  return self.__add__(other)
43
49
 
44
50
  def __mul__(self, other):
45
- return type(self)(
46
- ast.BinOp(left=self._node, op=ast.Mult(), right=type(self)(other)._node)
47
- )
51
+ other = type(self)(other)
52
+
53
+ if isinstance(self._node, ast.Constant) and self._node.value == 0:
54
+ return type(self)(0)
55
+
56
+ if isinstance(other._node, ast.Constant) and other._node.value == 0:
57
+ return type(self)(0)
58
+
59
+ if isinstance(self._node, ast.Constant) and self._node.value == 1:
60
+ return other
61
+
62
+ if isinstance(other._node, ast.Constant) and other._node.value == 1:
63
+ return self
64
+
65
+ return type(self)(ast.BinOp(left=self._node, op=ast.Mult(), right=other._node))
48
66
 
49
67
  def __rmul__(self, other):
50
68
  return self.__mul__(other)
51
69
 
52
70
  def __floordiv__(self, other):
71
+ other = type(self)(other)
72
+
73
+ if isinstance(other._node, ast.Constant) and other._node.value == 1:
74
+ return self
75
+
53
76
  return type(self)(
54
- ast.BinOp(left=self._node, op=ast.FloorDiv(), right=type(self)(other)._node)
77
+ ast.BinOp(left=self._node, op=ast.FloorDiv(), right=other._node)
55
78
  )
56
79
 
57
80
  def __mod__(self, other):
ninetoothed/tensor.py CHANGED
@@ -103,11 +103,12 @@ class Tensor:
103
103
  indices = self.indices()
104
104
 
105
105
  if not isinstance(self.dtype, type(self)):
106
- if indices:
106
+ if len(indices) != self.ndim():
107
107
  raise IndexError("Incorrect number of indices.")
108
108
 
109
109
  return sum(
110
- self.stride(idx)
110
+ indices[idx]
111
+ * self.stride(idx)
111
112
  * call("arange", 0, self.size(idx))[
112
113
  tuple(slice(None) if i == idx else None for i in range(self.ndim()))
113
114
  ]
@@ -131,8 +132,21 @@ class Tensor:
131
132
  indices.append(index // stride)
132
133
  index %= stride
133
134
 
135
+ curr = self.dtype
136
+ while isinstance(curr, type(self)):
137
+ indices.extend(
138
+ 0 if curr is not self.inmost() else 1 for _ in range(curr.ndim())
139
+ )
140
+ curr = curr.dtype
141
+
134
142
  return tuple(indices)
135
143
 
144
+ def inmost(self):
145
+ if not isinstance(self.dtype, type(self)):
146
+ return self
147
+
148
+ return self.dtype.inmost()
149
+
136
150
  def ndim(self):
137
151
  return len(self.shape)
138
152
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: ninetoothed
3
- Version: 0.1.1
3
+ Version: 0.2.0
4
4
  Summary: A domain-specific language based on Triton but providing higher-level abstraction.
5
5
  Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
6
6
  Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
@@ -0,0 +1,10 @@
1
+ ninetoothed/__init__.py,sha256=T5UJXlC-wbo8JKPbLUNT65Kccp12xP52WFV5FsugETI,147
2
+ ninetoothed/jit.py,sha256=hmUzkFZzsiKLgOHbsN0MAr1G1JCiyQ22cFPtmyZ1OyE,12725
3
+ ninetoothed/language.py,sha256=cSuTgi5OwmLFy-dy_AHGZzRm18wz01ByHQ2vioP1vTg,437
4
+ ninetoothed/symbol.py,sha256=I2Mc9D1w7AYAIQtyAXyDQ-FBqowVZrd-PK-JOt_SpgA,3787
5
+ ninetoothed/tensor.py,sha256=RfwYzdYASkr6usJklESm1n8RoxvYjWnPtCjIfipa2fg,5000
6
+ ninetoothed/torchifier.py,sha256=JmIVQE8r0zr_RLExsRDOGNsMu0F7v6J_o22aWqlw81k,841
7
+ ninetoothed-0.2.0.dist-info/METADATA,sha256=w6qkc2riniG0N4nDUCUkZWF8Eve3j5brBQHIWIEqLXQ,5422
8
+ ninetoothed-0.2.0.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
9
+ ninetoothed-0.2.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
10
+ ninetoothed-0.2.0.dist-info/RECORD,,
@@ -1,10 +0,0 @@
1
- ninetoothed/__init__.py,sha256=T5UJXlC-wbo8JKPbLUNT65Kccp12xP52WFV5FsugETI,147
2
- ninetoothed/jit.py,sha256=DdRdZ7DhfZwJeS7AcO_RhD9TZcCebKI55V4_6UHs3bo,10523
3
- ninetoothed/language.py,sha256=cSuTgi5OwmLFy-dy_AHGZzRm18wz01ByHQ2vioP1vTg,437
4
- ninetoothed/symbol.py,sha256=8BI4ekeLuUdHTEREvMMlAzwrJ93pqiCdSHGc38clBFA,3034
5
- ninetoothed/tensor.py,sha256=o_HLEuaBzojmbMLnbPGLcw4iqBI34TNdES3YLTagztE,4590
6
- ninetoothed/torchifier.py,sha256=JmIVQE8r0zr_RLExsRDOGNsMu0F7v6J_o22aWqlw81k,841
7
- ninetoothed-0.1.1.dist-info/METADATA,sha256=1Nv6Xcz7CrpEUrzAYH93bYVX8GfPtHwzj4yofeaoJro,5422
8
- ninetoothed-0.1.1.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
9
- ninetoothed-0.1.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
10
- ninetoothed-0.1.1.dist-info/RECORD,,