ninetoothed 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ninetoothed/jit.py CHANGED
@@ -1,10 +1,10 @@
1
1
  import ast
2
+ import collections
2
3
  import functools
3
4
  import inspect
4
5
  import itertools
5
6
  import math
6
7
  import tempfile
7
- import textwrap
8
8
 
9
9
  from ninetoothed.language import attribute, call
10
10
  from ninetoothed.symbol import Symbol
@@ -12,6 +12,67 @@ from ninetoothed.tensor import Tensor
12
12
  from ninetoothed.torchifier import Torchifier
13
13
 
14
14
 
15
def jit(func):
    """Compile ``func`` through :class:`JIT` and return the resulting handle.

    Meant to be used as a decorator (``@ninetoothed.jit``). Translation and
    per-definition caching are handled by :class:`JIT`.
    """
    compiler = JIT(func)
    return compiler()
17
+
18
+
19
class JIT:
    """Translate a NineToothed function into generated source and load it.

    Handles are cached at class level in ``handles``. The cache is keyed by
    source file and then by the first source line of the function, so
    compiling the same definition again returns the first handle.
    """

    # source file -> {first source line -> _Handle}
    handles = collections.defaultdict(dict)

    def __init__(self, func):
        self.func = func

    def __call__(self):
        """Return the cached handle for ``self.func``, compiling it if needed."""
        source_file = inspect.getsourcefile(self.func)
        source_line = inspect.getsourcelines(self.func)[1]

        cached = type(self).handles[source_file]

        if source_line in cached:
            return cached[source_line]

        tree = self._get_tree()

        CodeGenerator(inspect.get_annotations(self.func)).visit(tree)
        Tritonizer().visit(tree)
        ast.fix_missing_locations(tree)

        unparsed = ast.unparse(tree).replace("None:", ":").replace(":None", ":")

        # The generated code is written to a real file that is kept on disk
        # (delete=False). Presumably this lets downstream tooling (e.g.
        # Triton's JIT) re-read the kernel source via the code object's
        # filename. TODO confirm.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".py") as temp_file:
            temp_file.write(unparsed.encode("utf-8"))
            temp_file_name = temp_file.name

        # Compile the in-memory text rather than re-reading the file. The file
        # was written as UTF-8, but `open(..., "r")` decodes with the locale's
        # default encoding, which can corrupt non-ASCII source on non-UTF-8
        # systems.
        code = compile(source=unparsed, filename=temp_file_name, mode="exec")

        namespace = {}
        exec(code, namespace)

        handle = _Handle(
            namespace[self.func.__name__],
            namespace[f"launch_{self.func.__name__}"],
        )

        cached[source_line] = handle

        return handle

    def _get_tree(self):
        """Return a module AST containing only ``self.func``'s definition.

        The whole defining module is parsed first, so that import aliases
        (e.g. ``import ninetoothed.language as ntl``) can be expanded to their
        full dotted names before the function definition is extracted.

        Raises:
            ValueError: if no ``def`` named like the function is found in the
                source of its module.
        """
        module = ast.parse(inspect.getsource(inspect.getmodule(self.func)))

        _AliasRestorer().visit(module)
        finder = _FunctionDefFinder(self.func.__name__)
        finder.visit(module)

        if finder.result is None:
            # Without this check, `ast.Module(body=[None])` fails much later
            # with an unrelated-looking error.
            raise ValueError(
                f"could not find the definition of `{self.func.__name__}` "
                "in the source of its module"
            )

        return ast.Module(body=[finder.result], type_ignores=[])
74
+
75
+
15
76
  class CodeGenerator(ast.NodeTransformer):
16
77
  def __init__(self, context):
17
78
  super().__init__()
@@ -38,6 +99,18 @@ class CodeGenerator(ast.NodeTransformer):
38
99
 
39
100
  self.generic_visit(node)
40
101
 
102
+ for arg in self._args:
103
+ if not isinstance(arg, Tensor):
104
+ continue
105
+
106
+ node.body.insert(
107
+ 0,
108
+ ast.Assign(
109
+ targets=[Symbol(f"{arg.name}_ptrs").node],
110
+ value=arg.pointers().node,
111
+ ),
112
+ )
113
+
41
114
  return node
42
115
 
43
116
  def visit_arguments(self, node):
@@ -74,12 +147,12 @@ class CodeGenerator(ast.NodeTransformer):
74
147
  value = self._context[node.value.id]
75
148
 
76
149
  if isinstance(value, Tensor):
77
- if isinstance(node.slice, ast.Tuple):
78
- indices = value.indices() + tuple(node.slice.elts)
79
- else:
80
- indices = value.indices() + (node.slice,)
81
- offsets = value.offsets(indices)
82
- pointers = value.pointers(offsets)
150
+ pointers = type(self)._create_pointers(
151
+ value,
152
+ node.slice.elts
153
+ if isinstance(node.slice, ast.Tuple)
154
+ else (node.slice,),
155
+ )
83
156
 
84
157
  return call("load", pointers).node
85
158
 
@@ -104,7 +177,9 @@ class CodeGenerator(ast.NodeTransformer):
104
177
  self.generic_visit(node)
105
178
 
106
179
  if node.id in self._context and isinstance(node.ctx, ast.Load):
107
- return call("load", self._context[node.id].pointers().node).node
180
+ return call(
181
+ "load", type(self)._create_pointers(self._context[node.id], ()).node
182
+ ).node
108
183
 
109
184
  return node
110
185
 
@@ -118,7 +193,7 @@ class CodeGenerator(ast.NodeTransformer):
118
193
  return ast.Expr(
119
194
  call(
120
195
  "store",
121
- self._context[target.id].pointers().node,
196
+ type(self)._create_pointers(self._context[target.id], ()).node,
122
197
  node.value,
123
198
  ).node
124
199
  )
@@ -133,13 +208,12 @@ class CodeGenerator(ast.NodeTransformer):
133
208
  if isinstance(value, Tensor):
134
209
  self.generic_visit(node)
135
210
 
136
- indices = value.indices() + tuple(
211
+ pointers = type(self)._create_pointers(
212
+ value,
137
213
  target.slice.elts
138
214
  if isinstance(target.slice, ast.Tuple)
139
- else target.slice
215
+ else (target.slice,),
140
216
  )
141
- offsets = value.offsets(indices)
142
- pointers = value.pointers(offsets)
143
217
 
144
218
  return ast.Expr(
145
219
  call(
@@ -254,6 +328,14 @@ class CodeGenerator(ast.NodeTransformer):
254
328
 
255
329
  return ast.parse(f"lambda meta: ({num_elements},)", mode="eval").body
256
330
 
331
+ @staticmethod
332
+ def _create_pointers(tensor, indices):
333
+ return Symbol(f"{tensor.name}_ptrs") + tensor.offsets(
334
+ [0 for _ in range(tensor.ndim())]
335
+ + list(indices)
336
+ + [0 for _ in range(tensor.inmost().ndim())]
337
+ )
338
+
257
339
 
258
340
  class Tritonizer(ast.NodeTransformer):
259
341
  def visit_Module(self, node):
@@ -267,8 +349,8 @@ class Tritonizer(ast.NodeTransformer):
267
349
  def visit_Name(self, node):
268
350
  self.generic_visit(node)
269
351
 
270
- if node.id == "ninetoothed":
271
- node.id = "triton"
352
+ if node.id == "ninetoothed" or "ninetoothed." in node.id:
353
+ node.id = node.id.replace("ninetoothed", "triton")
272
354
 
273
355
  return node
274
356
 
@@ -288,32 +370,71 @@ class Tritonizer(ast.NodeTransformer):
288
370
  return node
289
371
 
290
372
 
291
- def jit(func):
292
- source = textwrap.dedent(inspect.getsource(func))
293
- tree = ast.parse(source)
373
+ class _Handle:
374
+ def __init__(self, kernel, launch):
375
+ self._kernel = kernel
376
+ self._launch = launch
377
+
378
+ def __call__(self, *args, **kwargs):
379
+ return self._launch(*args, **kwargs)
380
+
381
+
382
+ class _AliasRestorer(ast.NodeTransformer):
383
+ def __init__(self):
384
+ super().__init__()
385
+
386
+ self._aliases = {}
387
+ self._redefined = set()
388
+
389
+ def visit_Import(self, node):
390
+ for alias in node.names:
391
+ if alias.asname:
392
+ self._aliases[alias.asname] = alias.name
393
+
394
+ return node
395
+
396
+ def visit_ImportFrom(self, node):
397
+ for alias in node.names:
398
+ full_name = f"{node.module}.{alias.name}"
399
+ if alias.asname:
400
+ self._aliases[alias.asname] = full_name
401
+
402
+ return node
294
403
 
295
- CodeGenerator(func.__annotations__).visit(tree)
296
- Tritonizer().visit(tree)
297
- ast.fix_missing_locations(tree)
404
+ def visit_Assign(self, node):
405
+ for target in node.targets:
406
+ if isinstance(target, ast.Name):
407
+ self._redefined.add(target.id)
408
+
409
+ return self.generic_visit(node)
410
+
411
+ def visit_FunctionDef(self, node):
412
+ original_redefined = self._redefined.copy()
413
+
414
+ self.generic_visit(node)
415
+
416
+ self._redefined = original_redefined
298
417
 
299
- unparsed = ast.unparse(tree).replace("None:", ":").replace(":None", ":")
418
+ return node
419
+
420
+ def visit_Name(self, node):
421
+ if node.id in self._redefined:
422
+ return node
300
423
 
301
- with tempfile.NamedTemporaryFile(delete=False, suffix=".py") as temp_file:
302
- temp_file.write(unparsed.encode("utf-8"))
303
- temp_file_name = temp_file.name
424
+ if node.id in self._aliases:
425
+ return ast.Name(id=self._aliases[node.id], ctx=node.ctx)
304
426
 
305
- with open(temp_file_name, "r") as temp_file:
306
- code = compile(source=temp_file.read(), filename=temp_file_name, mode="exec")
427
+ return node
307
428
 
308
- namespace = {}
309
- exec(code, namespace)
310
429
 
311
- class Handle:
312
- def __init__(self, kernel, launch):
313
- self._kernel = kernel
314
- self._launch = launch
430
+ class _FunctionDefFinder(ast.NodeVisitor):
431
+ def __init__(self, name):
432
+ self._name = name
315
433
 
316
- def __call__(self, *args, **kwargs):
317
- return self._launch(*args, **kwargs)
434
+ self.result = None
318
435
 
319
- return Handle(namespace[func.__name__], namespace[f"launch_{func.__name__}"])
436
+ def visit_FunctionDef(self, node):
437
+ if node.name == self._name:
438
+ self.result = node
439
+
440
+ self.generic_visit(node)
ninetoothed/symbol.py CHANGED
@@ -34,24 +34,47 @@ class Symbol:
34
34
  self._node.id = type(self)._create_constexpr(self._node.id)
35
35
 
36
36
  def __add__(self, other):
37
- return type(self)(
38
- ast.BinOp(left=self._node, op=ast.Add(), right=type(self)(other)._node)
39
- )
37
+ other = type(self)(other)
38
+
39
+ if isinstance(self._node, ast.Constant) and self._node.value == 0:
40
+ return other
41
+
42
+ if isinstance(other._node, ast.Constant) and other._node.value == 0:
43
+ return self
44
+
45
+ return type(self)(ast.BinOp(left=self._node, op=ast.Add(), right=other._node))
40
46
 
41
47
  def __radd__(self, other):
42
48
  return self.__add__(other)
43
49
 
44
50
  def __mul__(self, other):
45
- return type(self)(
46
- ast.BinOp(left=self._node, op=ast.Mult(), right=type(self)(other)._node)
47
- )
51
+ other = type(self)(other)
52
+
53
+ if isinstance(self._node, ast.Constant) and self._node.value == 0:
54
+ return type(self)(0)
55
+
56
+ if isinstance(other._node, ast.Constant) and other._node.value == 0:
57
+ return type(self)(0)
58
+
59
+ if isinstance(self._node, ast.Constant) and self._node.value == 1:
60
+ return other
61
+
62
+ if isinstance(other._node, ast.Constant) and other._node.value == 1:
63
+ return self
64
+
65
+ return type(self)(ast.BinOp(left=self._node, op=ast.Mult(), right=other._node))
48
66
 
49
67
  def __rmul__(self, other):
50
68
  return self.__mul__(other)
51
69
 
52
70
  def __floordiv__(self, other):
71
+ other = type(self)(other)
72
+
73
+ if isinstance(other._node, ast.Constant) and other._node.value == 1:
74
+ return self
75
+
53
76
  return type(self)(
54
- ast.BinOp(left=self._node, op=ast.FloorDiv(), right=type(self)(other)._node)
77
+ ast.BinOp(left=self._node, op=ast.FloorDiv(), right=other._node)
55
78
  )
56
79
 
57
80
  def __mod__(self, other):
ninetoothed/tensor.py CHANGED
@@ -46,7 +46,7 @@ class Tensor:
46
46
  new_size = call("cdiv", size, tile_size)
47
47
  outer_shape.append(new_size)
48
48
 
49
- new_stride = call("cdiv", stride * size, (new_size * tile_stride))
49
+ new_stride = stride * tile_size // tile_stride
50
50
  outer_strides.append(new_stride)
51
51
 
52
52
  inner_shape.append(tile_size)
@@ -103,11 +103,12 @@ class Tensor:
103
103
  indices = self.indices()
104
104
 
105
105
  if not isinstance(self.dtype, type(self)):
106
- if indices:
106
+ if len(indices) != self.ndim():
107
107
  raise IndexError("Incorrect number of indices.")
108
108
 
109
109
  return sum(
110
- self.stride(idx)
110
+ indices[idx]
111
+ * self.stride(idx)
111
112
  * call("arange", 0, self.size(idx))[
112
113
  tuple(slice(None) if i == idx else None for i in range(self.ndim()))
113
114
  ]
@@ -131,8 +132,21 @@ class Tensor:
131
132
  indices.append(index // stride)
132
133
  index %= stride
133
134
 
135
+ curr = self.dtype
136
+ while isinstance(curr, type(self)):
137
+ indices.extend(
138
+ 0 if curr is not self.inmost() else 1 for _ in range(curr.ndim())
139
+ )
140
+ curr = curr.dtype
141
+
134
142
  return tuple(indices)
135
143
 
144
+ def inmost(self):
145
+ if not isinstance(self.dtype, type(self)):
146
+ return self
147
+
148
+ return self.dtype.inmost()
149
+
136
150
  def ndim(self):
137
151
  return len(self.shape)
138
152
 
@@ -0,0 +1,79 @@
1
+ Metadata-Version: 2.3
2
+ Name: ninetoothed
3
+ Version: 0.2.0
4
+ Summary: A domain-specific language based on Triton but providing higher-level abstraction.
5
+ Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
6
+ Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
7
+ Author-email: Jiacheng Huang <huangjiacheng0709@outlook.com>
8
+ License-File: LICENSE
9
+ Classifier: License :: OSI Approved :: Apache Software License
10
+ Classifier: Operating System :: OS Independent
11
+ Classifier: Programming Language :: Python :: 3
12
+ Requires-Python: >=3.10
13
+ Description-Content-Type: text/markdown
14
+
15
+ # NineToothed
16
+
17
+ A domain-specific language (DSL) based on Triton but providing higher-level abstractions.
18
+
19
+ **Other language versions: [English](README.md), [简体中文](docs/README.zh.md).**
20
+
21
+ ## Installation
22
+
23
+ We can use `pip` to install `ninetoothed`.
24
+
25
+ ```shell
26
+ pip install ninetoothed
27
+ ```
28
+
29
+ After successfully running the above command, `ninetoothed` will be installed. However, to fully utilize its capabilities, you also need to install `triton` and a deep learning framework supported by `ninetoothed`. For trial purposes, we recommend installing `triton` and `torch`.
30
+
31
+ ## Usage
32
+
33
+ Currently, we can use the `Tensor` and `Symbol` classes in the `ninetoothed` package to perform meta-operations like `tile` and `expand` to easily construct kernel functions. Below, we will use these features to create vector addition and matrix multiplication kernel functions.
34
+
35
+ ### Vector Addition
36
+
37
+ ```python
38
+ BLOCK_SIZE = Symbol("BLOCK_SIZE", meta=True)
39
+
40
+ @ninetoothed.jit
41
+ def add_kernel(
42
+ x: Tensor(1).tile((BLOCK_SIZE,)),
43
+ y: Tensor(1).tile((BLOCK_SIZE,)),
44
+ z: Tensor(1).tile((BLOCK_SIZE,)),
45
+ ):
46
+ z = x + y
47
+ ```
48
+
49
+ In this code, we first define `BLOCK_SIZE`, which is a `Symbol`. You can think of `"BLOCK_SIZE"` as its name. Setting `meta` to `True` tells the compiler that `BLOCK_SIZE` is a meta-parameter, so the compiler can determine its value. `Tensor(1)` constructs a one-dimensional tensor (a vector), and `Tensor(1).tile((BLOCK_SIZE,))` means we want to create a vector and divide it into blocks of size `BLOCK_SIZE`. For example, if the size of this vector is `8192` and `BLOCK_SIZE` is `1024`, the vector will be divided into `8` blocks, each of size `1024`.
50
+
51
+ By using type annotations, we tell the compiler that we will have three tensor parameters, which will be divided into blocks, and `x`, `y`, and `z` are these blocks. It's important to understand that `x`, `y`, and `z` are the blocks, not the tensors themselves. In the function body, `x`, `y`, and `z` are also the blocks. The rest is straightforward (only one line, `z = x + y`, remains): we add each block of `x` to the corresponding block of `y` and store the result in the corresponding block of `z`. Since every block of the parameter tensors undergoes this operation, the addition is completed for the whole tensors as well.
52
+
53
+ ### Matrix Multiplication
54
+
55
+ ```python
56
+ BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True)
57
+ BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True)
58
+ BLOCK_SIZE_K = Symbol("BLOCK_SIZE_K", meta=True)
59
+
60
+ a_tiled = Tensor(2).tile((BLOCK_SIZE_M, BLOCK_SIZE_K)).tile((1, -1))
61
+ b_tiled = Tensor(2).tile((BLOCK_SIZE_K, BLOCK_SIZE_N)).tile((-1, 1))
62
+ c_tiled = Tensor(2).tile((BLOCK_SIZE_M, BLOCK_SIZE_N))
63
+
64
+ a_tiled = a_tiled.expand((-1, c_tiled.shape[1]))
65
+ b_tiled = b_tiled.expand((c_tiled.shape[0], -1))
66
+
67
+ @ninetoothed.jit
68
+ def matmul_kernel(a: a_tiled, b: b_tiled, c: c_tiled):
69
+ accumulator = ninetoothed.language.zeros(
70
+ c.shape, dtype=ninetoothed.language.float32
71
+ )
72
+ for k in range(a.shape[1]):
73
+ accumulator = ninetoothed.language.dot(a[0, k], b[k, 0], accumulator)
74
+ c = accumulator.to(ninetoothed.language.float16)
75
+ ```
76
+
77
+ For matrix multiplication, we also have three tensor parameters, but the tiling method is more complex than vector addition. We denote the three matrices as $A$, $B$, and $C$, where $A$ and $B$ are inputs, and $C$ is the output. Tiling $C$ is simple; we just need to divide it into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_N)` by rows and columns. Once each block computes its result, the entire $C$ is computed. However, how should we tile $A$ and $B$? The answer is to introduce another meta-parameter `BLOCK_SIZE_K`. This way, we can divide $A$ into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_K)` and $B$ into blocks of size `(BLOCK_SIZE_K, BLOCK_SIZE_N)`. However, for matrix multiplication, $A$ and $B$ do not correspond block by block; each row of $A$ needs to correspond to each column of $B$. Therefore, we need to further `tile` $A$ and $B$ by rows and columns, respectively. Up to this point, we have a set of row blocks of $A$ and column blocks of $B$. However, each row block of $A$ must correspond to every column block of $B$. This is where `expand` comes in. We `expand` the row blocks of $A$ along the columns to the number of columns of $C$ and the column blocks of $B$ along the rows to the number of rows of $C$. This way, we successfully tile $A$, $B$, and $C$.
78
+
79
+ With tiling done, the rest is simple. In the function body, we define an `accumulator` to accumulate intermediate results. We then iterate through the corresponding row blocks of $A$ and column blocks of $B$, multiplying them and accumulating the results in `accumulator`. Finally, we place the `accumulator` in the corresponding block of $C$. Since each block of the parameter tensors undergoes this operation, the multiplication is completed for the whole tensors as well.
@@ -0,0 +1,10 @@
1
+ ninetoothed/__init__.py,sha256=T5UJXlC-wbo8JKPbLUNT65Kccp12xP52WFV5FsugETI,147
2
+ ninetoothed/jit.py,sha256=hmUzkFZzsiKLgOHbsN0MAr1G1JCiyQ22cFPtmyZ1OyE,12725
3
+ ninetoothed/language.py,sha256=cSuTgi5OwmLFy-dy_AHGZzRm18wz01ByHQ2vioP1vTg,437
4
+ ninetoothed/symbol.py,sha256=I2Mc9D1w7AYAIQtyAXyDQ-FBqowVZrd-PK-JOt_SpgA,3787
5
+ ninetoothed/tensor.py,sha256=RfwYzdYASkr6usJklESm1n8RoxvYjWnPtCjIfipa2fg,5000
6
+ ninetoothed/torchifier.py,sha256=JmIVQE8r0zr_RLExsRDOGNsMu0F7v6J_o22aWqlw81k,841
7
+ ninetoothed-0.2.0.dist-info/METADATA,sha256=w6qkc2riniG0N4nDUCUkZWF8Eve3j5brBQHIWIEqLXQ,5422
8
+ ninetoothed-0.2.0.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
9
+ ninetoothed-0.2.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
10
+ ninetoothed-0.2.0.dist-info/RECORD,,
@@ -1,19 +0,0 @@
1
- Metadata-Version: 2.3
2
- Name: ninetoothed
3
- Version: 0.1.0
4
- Summary: A domain-specific language based on Triton but providing higher-level abstraction.
5
- Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
6
- Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
7
- Author-email: Jiacheng Huang <huangjiacheng0709@outlook.com>
8
- License-File: LICENSE
9
- Classifier: License :: OSI Approved :: Apache Software License
10
- Classifier: Operating System :: OS Independent
11
- Classifier: Programming Language :: Python :: 3
12
- Requires-Python: >=3.10
13
- Description-Content-Type: text/markdown
14
-
15
- # Nine-Toothed
16
-
17
- A domain-specific language based on Triton but providing higher-level abstraction.
18
-
19
- **Read this in other languages: [English](README.md), [简体中文](docs/README.zh.md).**
@@ -1,10 +0,0 @@
1
- ninetoothed/__init__.py,sha256=T5UJXlC-wbo8JKPbLUNT65Kccp12xP52WFV5FsugETI,147
2
- ninetoothed/jit.py,sha256=mnBtsrD84usfYEozAclKBqW3Rrl1OEAolhsKRvrOTKU,9735
3
- ninetoothed/language.py,sha256=cSuTgi5OwmLFy-dy_AHGZzRm18wz01ByHQ2vioP1vTg,437
4
- ninetoothed/symbol.py,sha256=8BI4ekeLuUdHTEREvMMlAzwrJ93pqiCdSHGc38clBFA,3034
5
- ninetoothed/tensor.py,sha256=RMHgADBTdj5Q18Ttre4baq6tG_mqC4VrSn0AV6BL6VQ,4610
6
- ninetoothed/torchifier.py,sha256=JmIVQE8r0zr_RLExsRDOGNsMu0F7v6J_o22aWqlw81k,841
7
- ninetoothed-0.1.0.dist-info/METADATA,sha256=uM1Bs_zmjwgGtWJMBKejFRyiC0jO209PHS33btFMTGA,783
8
- ninetoothed-0.1.0.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
9
- ninetoothed-0.1.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
10
- ninetoothed-0.1.0.dist-info/RECORD,,