ninetoothed 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ninetoothed/jit.py CHANGED
@@ -137,7 +137,7 @@ class CodeGenerator(ast.NodeTransformer):
137
137
  node.args = [
138
138
  ast.arg(arg=name)
139
139
  if not Symbol.is_constexpr(name)
140
- else ast.arg(arg=name, annotation=attribute("constexpr"))
140
+ else ast.arg(arg=name, annotation=attribute("constexpr").node)
141
141
  for name in non_meta_names
142
142
  ] + [
143
143
  ast.arg(arg=name, annotation=attribute("constexpr").node)
@@ -178,7 +178,7 @@ class CodeGenerator(ast.NodeTransformer):
178
178
  if isinstance(value, Tensor):
179
179
  inner = value.dtype
180
180
 
181
- return Symbol(inner.__dict__[node.attr]).node
181
+ return Symbol(getattr(inner, node.attr)).node
182
182
 
183
183
  self.generic_visit(node)
184
184
 
@@ -287,15 +287,30 @@ class CodeGenerator(ast.NodeTransformer):
287
287
  )
288
288
 
289
289
  def _generate_launch(self, params, meta):
290
+ constexpr_params = [param for param in params if Symbol.is_constexpr(param)]
291
+ constexpr_params_without_prefixes = [
292
+ Symbol.remove_prefix(param) for param in constexpr_params
293
+ ]
294
+
290
295
  launch = ast.FunctionDef(
291
296
  name=f"launch_{self._func_def.name}",
292
297
  args=ast.arguments(
293
298
  posonlyargs=[],
294
- args=[ast.arg(arg.original.name) for arg in self._args],
299
+ args=[ast.arg(arg=arg.original.name) for arg in self._args]
300
+ + [ast.arg(arg=param) for param in constexpr_params_without_prefixes],
295
301
  kwonlyargs=[],
296
302
  defaults=[],
297
303
  ),
298
304
  body=[
305
+ ast.Assign(
306
+ targets=[ast.Name(id=param, ctx=ast.Store())],
307
+ value=ast.Name(id=param_without_prefix, ctx=ast.Load()),
308
+ )
309
+ for param, param_without_prefix in zip(
310
+ constexpr_params, constexpr_params_without_prefixes
311
+ )
312
+ ]
313
+ + [
299
314
  ast.Expr(
300
315
  ast.Call(
301
316
  func=ast.Subscript(
ninetoothed/symbol.py CHANGED
@@ -137,19 +137,39 @@ class Symbol:
137
137
 
138
138
  @staticmethod
139
139
  def is_constexpr(name):
140
- return name.startswith("_ninetoothed_constexpr_") or Symbol.is_meta(name)
140
+ return name.startswith(Symbol._constexpr_prefix()) or Symbol.is_meta(name)
141
141
 
142
142
  @staticmethod
143
143
  def is_meta(name):
144
- return name.startswith("_ninetoothed_meta_")
144
+ return name.startswith(Symbol._meta_prefix())
145
+
146
+ @staticmethod
147
+ def remove_prefix(name):
148
+ if name.startswith(Symbol._constexpr_prefix()):
149
+ return name.removeprefix(Symbol._constexpr_prefix())
150
+
151
+ if name.startswith(Symbol._meta_prefix()):
152
+ return name.removeprefix(Symbol._meta_prefix())
145
153
 
146
154
  @staticmethod
147
155
  def _create_constexpr(name):
148
- return f"_ninetoothed_constexpr_{name}"
156
+ return f"{Symbol._constexpr_prefix()}{name}"
149
157
 
150
158
  @staticmethod
151
159
  def _create_meta(name):
152
- return f"_ninetoothed_meta_{name}"
160
+ return f"{Symbol._meta_prefix()}{name}"
161
+
162
+ @staticmethod
163
+ def _constexpr_prefix():
164
+ return f"{Symbol._ninetoothed_prefix()}constexpr_"
165
+
166
+ @staticmethod
167
+ def _meta_prefix():
168
+ return f"{Symbol._ninetoothed_prefix()}meta_"
169
+
170
+ @staticmethod
171
+ def _ninetoothed_prefix():
172
+ return "_ninetoothed_"
153
173
 
154
174
 
155
175
  class _FindAndReplacer(ast.NodeTransformer):
ninetoothed/tensor.py CHANGED
@@ -24,8 +24,8 @@ class Tensor:
24
24
  self.name = f"tensor_{type(self).num_instances}"
25
25
 
26
26
  if ndim is not None:
27
- self.shape = [Symbol(self.size_string(i)) for i in range(ndim)]
28
- self.strides = [Symbol(self.stride_string(i)) for i in range(ndim)]
27
+ self.shape = (Symbol(self.size_string(i)) for i in range(ndim))
28
+ self.strides = (Symbol(self.stride_string(i)) for i in range(ndim))
29
29
  else:
30
30
  self.shape = shape
31
31
 
@@ -191,6 +191,22 @@ class Tensor:
191
191
 
192
192
  return self.strides[dim]
193
193
 
194
+ @property
195
+ def shape(self):
196
+ return self._shape
197
+
198
+ @shape.setter
199
+ def shape(self, value):
200
+ self._shape = tuple(value)
201
+
202
+ @property
203
+ def strides(self):
204
+ return self._strides
205
+
206
+ @strides.setter
207
+ def strides(self, value):
208
+ self._strides = tuple(value)
209
+
194
210
  @property
195
211
  def ndim(self):
196
212
  return len(self.shape)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: ninetoothed
3
- Version: 0.3.0
3
+ Version: 0.5.0
4
4
  Summary: A domain-specific language based on Triton but providing higher-level abstraction.
5
5
  Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
6
6
  Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
@@ -15,6 +15,8 @@ Description-Content-Type: text/markdown
15
15
 
16
16
  # NineToothed
17
17
 
18
+ ![NineToothed Logo](docs/source/_static/ninetoothed-logo.png)
19
+
18
20
  A domain-specific language (DSL) based on Triton but providing higher-level abstractions.
19
21
 
20
22
  **Other language versions: [English](README.md), [简体中文](docs/README.zh.md).**
@@ -80,4 +82,4 @@ def matmul_kernel(a: a_tiled, b: b_tiled, c: c_tiled):
80
82
 
81
83
  For matrix multiplication, we also have three tensor parameters, but the tiling method is more complex than vector addition. We denote the three matrices as $A$, $B$, and $C$, where $A$ and $B$ are inputs, and $C$ is the output. Tiling $C$ is simple; we just need to divide it into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_N)` by rows and columns. Once each block computes its result, the entire $C$ is computed. However, how should we tile $A$ and $B$? The answer is to introduce another meta-parameter `BLOCK_SIZE_K`. This way, we can divide $A$ into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_K)` and $B$ into blocks of size `(BLOCK_SIZE_K, BLOCK_SIZE_N)`. However, for matrix multiplication, $A$ and $B$ do not correspond block by block; each row of $A$ needs to correspond to each column of $B$. Therefore, we need to further `tile` $A$ and $B$ by rows and columns, respectively. Up to this point, we have a set of row blocks of $A$ and column blocks of $B$. However, each row block of $A$ must correspond to every column block of $B$. This is where `expand` comes in. We `expand` the row blocks of $A$ along the columns to the number of columns of $C$ and the column blocks of $B$ along the rows to the number of rows of $C$. This way, we successfully tile $A$, $B$, and $C$. In fact, our meta-operations up to this point have already enabled us to write kernel functions. However, we notice that the levels where the row blocks and column blocks reside, which we mentioned earlier, are two-dimensional, and their sizes are of the forms `(1, ...)` and `(..., 1)`. This means that if no other operations are performed, the way we access row blocks and column blocks would have to be `a[0, k]` and `b[k, 0]`. If we want to use `a` to find the range of `k`, we would need to use `a.shape[1]`, but we know that dimensions of size `1` can actually be removed completely. This is why we added two lines of `squeeze`. 
The `dtype` refers to the data type, which in PyTorch can generally be some integer or floating-point type, such as `torch.float32`. However, since meta-operations like `tile` can be performed in NineToothed, `dtype` can also be a `Tensor`. In other words, there is a concept of "tensors that store tensors" in NineToothed. In summary, these two lines perform operations on the tensors stored in the outermost tensor, removing the dimensions of size `1`. This way, when we access the row and column blocks, we can use `a[k]` and `b[k]`, and when finding the range of `k`, we can use `a.shape[0]`.
82
84
 
83
- With tiling done, the rest is simple. In the function body, we define an `accumulator` to accumulate intermediate results. We then iterate through the corresponding row blocks of $A$ and column blocks of B, multiplying them and accumulating the results in `accumulator`. Finally, we place the `accumulator` in the corresponding block of $C$. Since each block of the parameter tensors undergoes this operation, the multiplication is completed for the whole tensors as well.
85
+ With tiling done, the rest is simple. In the function body, we define an `accumulator` to accumulate intermediate results. We then iterate through the corresponding row blocks of $A$ and column blocks of $B$, multiplying them and accumulating the results in `accumulator`. Finally, we place the `accumulator` in the corresponding block of $C$. Since each block of the parameter tensors undergoes this operation, the multiplication is completed for the whole tensors as well.
@@ -0,0 +1,10 @@
1
+ ninetoothed/__init__.py,sha256=T5UJXlC-wbo8JKPbLUNT65Kccp12xP52WFV5FsugETI,147
2
+ ninetoothed/jit.py,sha256=q-TRGF81rUEwV1TGDrew3ijwvzCWenR8EejZbYteZSI,16188
3
+ ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
4
+ ninetoothed/symbol.py,sha256=8Wg-JQPkVv9mMIxB1Rj4SHzOytHXPgHLkuK0BEFPDkc,5243
5
+ ninetoothed/tensor.py,sha256=UO79yYwHMfdqv32Ww2mtcl-ki1C9zInC0vBNwDtzlHU,6575
6
+ ninetoothed/torchifier.py,sha256=8M2PDwyFIfVypX6Z-Vt_bGbsCPqxqKnftL0rXeh9bOM,911
7
+ ninetoothed-0.5.0.dist-info/METADATA,sha256=ObwfQtwBk3x90adbQfiSo5wK11qUG9f4NdmunyjC--0,6785
8
+ ninetoothed-0.5.0.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
9
+ ninetoothed-0.5.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
10
+ ninetoothed-0.5.0.dist-info/RECORD,,
@@ -1,10 +0,0 @@
1
- ninetoothed/__init__.py,sha256=T5UJXlC-wbo8JKPbLUNT65Kccp12xP52WFV5FsugETI,147
2
- ninetoothed/jit.py,sha256=nhjZRi8_kcjWZX0eOrnxLlzJfVg5vn12f9oi0Er2ABE,15515
3
- ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
4
- ninetoothed/symbol.py,sha256=Bd54qcI8KQAX0JRE_wPXycswtdSofhZ6Rr5MtZcv9fo,4665
5
- ninetoothed/tensor.py,sha256=_DrjOJ-pBvEbSNUvUoYJduLQXmuKgNcqhe4xUDMVoZw,6275
6
- ninetoothed/torchifier.py,sha256=8M2PDwyFIfVypX6Z-Vt_bGbsCPqxqKnftL0rXeh9bOM,911
7
- ninetoothed-0.3.0.dist-info/METADATA,sha256=CqdtfdV0eHzSwxJmFpD2IG5d4WTc6RDlpqMZue4Ml2Q,6720
8
- ninetoothed-0.3.0.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
9
- ninetoothed-0.3.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
10
- ninetoothed-0.3.0.dist-info/RECORD,,