ninetoothed 0.4.0__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ninetoothed/jit.py CHANGED
@@ -1,9 +1,11 @@
1
1
  import ast
2
2
  import collections
3
3
  import functools
4
+ import importlib.util
4
5
  import inspect
5
6
  import itertools
6
7
  import math
8
+ import sys
7
9
  import tempfile
8
10
 
9
11
  import triton
@@ -41,24 +43,20 @@ class JIT:
41
43
  ast.fix_missing_locations(tree)
42
44
 
43
45
  unparsed = ast.unparse(tree).replace("None:", ":").replace(":None", ":")
46
+ dependencies = self._find_dependencies()
47
+ source = "\n\n".join((unparsed, dependencies)).strip()
44
48
 
45
49
  with tempfile.NamedTemporaryFile(delete=False, suffix=".py") as temp_file:
46
- temp_file.write(unparsed.encode("utf-8"))
50
+ temp_file.write(source.encode("utf-8"))
47
51
  temp_file_name = temp_file.name
48
52
 
49
- with open(temp_file_name, "r") as temp_file:
50
- code = compile(
51
- source=temp_file.read(),
52
- filename=temp_file_name,
53
- mode="exec",
54
- )
55
-
56
- namespace = {}
57
- exec(code, namespace)
53
+ module = type(self)._import_from_path(temp_file_name, temp_file_name)
54
+ module_vars = vars(module)
58
55
 
59
56
  handle = _Handle(
60
- namespace[self.func.__name__],
61
- namespace[f"launch_{self.func.__name__}"],
57
+ module_vars[self.func.__name__],
58
+ module_vars[f"launch_{self.func.__name__}"],
59
+ source,
62
60
  )
63
61
 
64
62
  type(self).handles[source_file][source_line] = handle
@@ -74,6 +72,24 @@ class JIT:
74
72
 
75
73
  return ast.Module(body=[finder.result], type_ignores=[])
76
74
 
75
+ def _find_dependencies(self):
76
+ dependencies = set()
77
+
78
+ for obj in self.func.__globals__.values():
79
+ if isinstance(obj, triton.runtime.JITFunction):
80
+ dependencies.add(obj.src)
81
+
82
+ return "\n".join(f"@triton.jit\n{dependency}" for dependency in dependencies)
83
+
84
+ @staticmethod
85
+ def _import_from_path(module_name, file_path):
86
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
87
+ module = importlib.util.module_from_spec(spec)
88
+ sys.modules[module_name] = module
89
+ spec.loader.exec_module(module)
90
+
91
+ return module
92
+
77
93
 
78
94
  class CodeGenerator(ast.NodeTransformer):
79
95
  def __init__(self, context):
@@ -102,7 +118,7 @@ class CodeGenerator(ast.NodeTransformer):
102
118
  self.generic_visit(node)
103
119
 
104
120
  for arg in self._args:
105
- if not isinstance(arg, Tensor):
121
+ if not isinstance(arg, Tensor) or arg.ndim == 0:
106
122
  continue
107
123
 
108
124
  offsets = arg.offsets()
@@ -178,7 +194,7 @@ class CodeGenerator(ast.NodeTransformer):
178
194
  if isinstance(value, Tensor):
179
195
  inner = value.dtype
180
196
 
181
- return Symbol(inner.__dict__[node.attr]).node
197
+ return Symbol(getattr(inner, node.attr)).node
182
198
 
183
199
  self.generic_visit(node)
184
200
 
@@ -355,6 +371,9 @@ class CodeGenerator(ast.NodeTransformer):
355
371
 
356
372
  @staticmethod
357
373
  def _generate_load(tensor, intermediate_indices=()):
374
+ if tensor.ndim == 0:
375
+ return Symbol(tensor.original.name).node
376
+
358
377
  pointers, mask = CodeGenerator._generate_pointers_and_mask(
359
378
  tensor, intermediate_indices
360
379
  )
@@ -459,9 +478,10 @@ class Tritonizer(ast.NodeTransformer):
459
478
 
460
479
 
461
480
  class _Handle:
462
- def __init__(self, kernel, launch):
481
+ def __init__(self, kernel, launch, source):
463
482
  self._kernel = kernel
464
483
  self._launch = launch
484
+ self._source = source
465
485
 
466
486
  def __call__(self, *args, **kwargs):
467
487
  return self._launch(*args, **kwargs)
ninetoothed/tensor.py CHANGED
@@ -21,11 +21,11 @@ class Tensor:
21
21
 
22
22
  self.dtype = dtype
23
23
 
24
- self.name = f"tensor_{type(self).num_instances}"
24
+ self.name = f"_ninetoothed_tensor_{type(self).num_instances}"
25
25
 
26
26
  if ndim is not None:
27
- self.shape = [Symbol(self.size_string(i)) for i in range(ndim)]
28
- self.strides = [Symbol(self.stride_string(i)) for i in range(ndim)]
27
+ self.shape = (Symbol(self.size_string(i)) for i in range(ndim))
28
+ self.strides = (Symbol(self.stride_string(i)) for i in range(ndim))
29
29
  else:
30
30
  self.shape = shape
31
31
 
@@ -103,6 +103,9 @@ class Tensor:
103
103
  )
104
104
 
105
105
  def names(self):
106
+ if self.ndim == 0:
107
+ return {self.original.name}
108
+
106
109
  return (
107
110
  {self.original.pointer_string()}
108
111
  | {
@@ -191,6 +194,22 @@ class Tensor:
191
194
 
192
195
  return self.strides[dim]
193
196
 
197
+ @property
198
+ def shape(self):
199
+ return self._shape
200
+
201
+ @shape.setter
202
+ def shape(self, value):
203
+ self._shape = tuple(value)
204
+
205
+ @property
206
+ def strides(self):
207
+ return self._strides
208
+
209
+ @strides.setter
210
+ def strides(self, value):
211
+ self._strides = tuple(value)
212
+
194
213
  @property
195
214
  def ndim(self):
196
215
  return len(self.shape)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: ninetoothed
3
- Version: 0.4.0
3
+ Version: 0.6.0
4
4
  Summary: A domain-specific language based on Triton but providing higher-level abstraction.
5
5
  Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
6
6
  Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
@@ -15,6 +15,8 @@ Description-Content-Type: text/markdown
15
15
 
16
16
  # NineToothed
17
17
 
18
+ ![NineToothed Logo](docs/source/_static/ninetoothed-logo.png)
19
+
18
20
  A domain-specific language (DSL) based on Triton but providing higher-level abstractions.
19
21
 
20
22
  **Other language versions: [English](README.md), [简体中文](docs/README.zh.md).**
@@ -80,4 +82,4 @@ def matmul_kernel(a: a_tiled, b: b_tiled, c: c_tiled):
80
82
 
81
83
  For matrix multiplication, we also have three tensor parameters, but the tiling method is more complex than vector addition. We denote the three matrices as $A$, $B$, and $C$, where $A$ and $B$ are inputs, and $C$ is the output. Tiling $C$ is simple; we just need to divide it into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_N)` by rows and columns. Once each block computes its result, the entire $C$ is computed. However, how should we tile $A$ and $B$? The answer is to introduce another meta-parameter `BLOCK_SIZE_K`. This way, we can divide $A$ into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_K)` and $B$ into blocks of size `(BLOCK_SIZE_K, BLOCK_SIZE_N)`. However, for matrix multiplication, $A$ and $B$ do not correspond block by block; each row of $A$ needs to correspond to each column of $B$. Therefore, we need to further `tile` $A$ and $B$ by rows and columns, respectively. Up to this point, we have a set of row blocks of $A$ and column blocks of $B$. However, each row block of $A$ must correspond to every column block of $B$. This is where `expand` comes in. We `expand` the row blocks of $A$ along the columns to the number of columns of $C$ and the column blocks of $B$ along the rows to the number of rows of $C$. This way, we successfully tile $A$, $B$, and $C$. In fact, our meta-operations up to this point have already enabled us to write kernel functions. However, we notice that the levels where the row blocks and column blocks reside, which we mentioned earlier, are two-dimensional, and their sizes are of the forms `(1, ...)` and `(..., 1)`. This means that if no other operations are performed, the way we access row blocks and column blocks would have to be `a[0, k]` and `b[k, 0]`. If we want to use `a` to find the range of `k`, we would need to use `a.shape[1]`, but we know that dimensions of size `1` can actually be removed completely. This is why we added two lines of `squeeze`. 
The `dtype` refers to the data type, which in PyTorch can generally be an integer or floating-point type, such as `torch.float32`. However, since meta-operations like `tile` can be performed in NineToothed, `dtype` can also be a `Tensor`. In other words, there is a concept of "tensors that store tensors" in NineToothed. In summary, these two lines perform operations on the tensors stored in the outermost tensor, removing the dimensions of size `1`. This way, when we access the row and column blocks, we can use `a[k]` and `b[k]`, and when finding the range of `k`, we can use `a.shape[0]`.
82
84
 
83
- With tiling done, the rest is simple. In the function body, we define an `accumulator` to accumulate intermediate results. We then iterate through the corresponding row blocks of $A$ and column blocks of B, multiplying them and accumulating the results in `accumulator`. Finally, we place the `accumulator` in the corresponding block of $C$. Since each block of the parameter tensors undergoes this operation, the multiplication is completed for the whole tensors as well.
85
+ With tiling done, the rest is simple. In the function body, we define an `accumulator` to accumulate intermediate results. We then iterate through the corresponding row blocks of $A$ and column blocks of $B$, multiplying them and accumulating the results in `accumulator`. Finally, we place the `accumulator` in the corresponding block of $C$. Since each block of the parameter tensors undergoes this operation, the multiplication is completed for the whole tensors as well.
@@ -0,0 +1,10 @@
1
+ ninetoothed/__init__.py,sha256=T5UJXlC-wbo8JKPbLUNT65Kccp12xP52WFV5FsugETI,147
2
+ ninetoothed/jit.py,sha256=5gNp4HixCkural_Ns3DxwT4LL3OUcG0ECj4NLjb-EYk,16959
3
+ ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
4
+ ninetoothed/symbol.py,sha256=8Wg-JQPkVv9mMIxB1Rj4SHzOytHXPgHLkuK0BEFPDkc,5243
5
+ ninetoothed/tensor.py,sha256=L-9LhwnM4uRtRvj3tqrzerUijEfKeTQvFBcmS1hQilI,6656
6
+ ninetoothed/torchifier.py,sha256=8M2PDwyFIfVypX6Z-Vt_bGbsCPqxqKnftL0rXeh9bOM,911
7
+ ninetoothed-0.6.0.dist-info/METADATA,sha256=zvY4nvKt7R8kWDYrGnApem_C07trLgOj1-7zXPfqD9U,6785
8
+ ninetoothed-0.6.0.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
9
+ ninetoothed-0.6.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
10
+ ninetoothed-0.6.0.dist-info/RECORD,,
@@ -1,10 +0,0 @@
1
- ninetoothed/__init__.py,sha256=T5UJXlC-wbo8JKPbLUNT65Kccp12xP52WFV5FsugETI,147
2
- ninetoothed/jit.py,sha256=ECjaHcrVNj1SBxoMdzjGi5iDp3rtv2jUiHjvK0eU6Cs,16188
3
- ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
4
- ninetoothed/symbol.py,sha256=8Wg-JQPkVv9mMIxB1Rj4SHzOytHXPgHLkuK0BEFPDkc,5243
5
- ninetoothed/tensor.py,sha256=_DrjOJ-pBvEbSNUvUoYJduLQXmuKgNcqhe4xUDMVoZw,6275
6
- ninetoothed/torchifier.py,sha256=8M2PDwyFIfVypX6Z-Vt_bGbsCPqxqKnftL0rXeh9bOM,911
7
- ninetoothed-0.4.0.dist-info/METADATA,sha256=Wgg0CP-j8VkiJWMpyOLOL7C1kVLkeF4OoZD6eyZsgLQ,6720
8
- ninetoothed-0.4.0.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
9
- ninetoothed-0.4.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
10
- ninetoothed-0.4.0.dist-info/RECORD,,