ninetoothed 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ninetoothed/jit.py +18 -3
- ninetoothed/symbol.py +24 -4
- ninetoothed/tensor.py +18 -2
- {ninetoothed-0.3.0.dist-info → ninetoothed-0.5.0.dist-info}/METADATA +4 -2
- ninetoothed-0.5.0.dist-info/RECORD +10 -0
- ninetoothed-0.3.0.dist-info/RECORD +0 -10
- {ninetoothed-0.3.0.dist-info → ninetoothed-0.5.0.dist-info}/WHEEL +0 -0
- {ninetoothed-0.3.0.dist-info → ninetoothed-0.5.0.dist-info}/licenses/LICENSE +0 -0
ninetoothed/jit.py
CHANGED
@@ -137,7 +137,7 @@ class CodeGenerator(ast.NodeTransformer):
|
|
137
137
|
node.args = [
|
138
138
|
ast.arg(arg=name)
|
139
139
|
if not Symbol.is_constexpr(name)
|
140
|
-
else ast.arg(arg=name, annotation=attribute("constexpr"))
|
140
|
+
else ast.arg(arg=name, annotation=attribute("constexpr").node)
|
141
141
|
for name in non_meta_names
|
142
142
|
] + [
|
143
143
|
ast.arg(arg=name, annotation=attribute("constexpr").node)
|
@@ -178,7 +178,7 @@ class CodeGenerator(ast.NodeTransformer):
|
|
178
178
|
if isinstance(value, Tensor):
|
179
179
|
inner = value.dtype
|
180
180
|
|
181
|
-
return Symbol(inner
|
181
|
+
return Symbol(getattr(inner, node.attr)).node
|
182
182
|
|
183
183
|
self.generic_visit(node)
|
184
184
|
|
@@ -287,15 +287,30 @@ class CodeGenerator(ast.NodeTransformer):
|
|
287
287
|
)
|
288
288
|
|
289
289
|
def _generate_launch(self, params, meta):
|
290
|
+
constexpr_params = [param for param in params if Symbol.is_constexpr(param)]
|
291
|
+
constexpr_params_without_prefixes = [
|
292
|
+
Symbol.remove_prefix(param) for param in constexpr_params
|
293
|
+
]
|
294
|
+
|
290
295
|
launch = ast.FunctionDef(
|
291
296
|
name=f"launch_{self._func_def.name}",
|
292
297
|
args=ast.arguments(
|
293
298
|
posonlyargs=[],
|
294
|
-
args=[ast.arg(arg.original.name) for arg in self._args]
|
299
|
+
args=[ast.arg(arg=arg.original.name) for arg in self._args]
|
300
|
+
+ [ast.arg(arg=param) for param in constexpr_params_without_prefixes],
|
295
301
|
kwonlyargs=[],
|
296
302
|
defaults=[],
|
297
303
|
),
|
298
304
|
body=[
|
305
|
+
ast.Assign(
|
306
|
+
targets=[ast.Name(id=param, ctx=ast.Store())],
|
307
|
+
value=ast.Name(id=param_without_prefix, ctx=ast.Load()),
|
308
|
+
)
|
309
|
+
for param, param_without_prefix in zip(
|
310
|
+
constexpr_params, constexpr_params_without_prefixes
|
311
|
+
)
|
312
|
+
]
|
313
|
+
+ [
|
299
314
|
ast.Expr(
|
300
315
|
ast.Call(
|
301
316
|
func=ast.Subscript(
|
ninetoothed/symbol.py
CHANGED
@@ -137,19 +137,39 @@ class Symbol:
|
|
137
137
|
|
138
138
|
@staticmethod
|
139
139
|
def is_constexpr(name):
|
140
|
-
return name.startswith(
|
140
|
+
return name.startswith(Symbol._constexpr_prefix()) or Symbol.is_meta(name)
|
141
141
|
|
142
142
|
@staticmethod
|
143
143
|
def is_meta(name):
|
144
|
-
return name.startswith(
|
144
|
+
return name.startswith(Symbol._meta_prefix())
|
145
|
+
|
146
|
+
@staticmethod
|
147
|
+
def remove_prefix(name):
|
148
|
+
if name.startswith(Symbol._constexpr_prefix()):
|
149
|
+
return name.removeprefix(Symbol._constexpr_prefix())
|
150
|
+
|
151
|
+
if name.startswith(Symbol._meta_prefix()):
|
152
|
+
return name.removeprefix(Symbol._meta_prefix())
|
145
153
|
|
146
154
|
@staticmethod
|
147
155
|
def _create_constexpr(name):
|
148
|
-
return f"
|
156
|
+
return f"{Symbol._constexpr_prefix()}{name}"
|
149
157
|
|
150
158
|
@staticmethod
|
151
159
|
def _create_meta(name):
|
152
|
-
return f"
|
160
|
+
return f"{Symbol._meta_prefix()}{name}"
|
161
|
+
|
162
|
+
@staticmethod
|
163
|
+
def _constexpr_prefix():
|
164
|
+
return f"{Symbol._ninetoothed_prefix()}constexpr_"
|
165
|
+
|
166
|
+
@staticmethod
|
167
|
+
def _meta_prefix():
|
168
|
+
return f"{Symbol._ninetoothed_prefix()}meta_"
|
169
|
+
|
170
|
+
@staticmethod
|
171
|
+
def _ninetoothed_prefix():
|
172
|
+
return "_ninetoothed_"
|
153
173
|
|
154
174
|
|
155
175
|
class _FindAndReplacer(ast.NodeTransformer):
|
ninetoothed/tensor.py
CHANGED
@@ -24,8 +24,8 @@ class Tensor:
|
|
24
24
|
self.name = f"tensor_{type(self).num_instances}"
|
25
25
|
|
26
26
|
if ndim is not None:
|
27
|
-
self.shape =
|
28
|
-
self.strides =
|
27
|
+
self.shape = (Symbol(self.size_string(i)) for i in range(ndim))
|
28
|
+
self.strides = (Symbol(self.stride_string(i)) for i in range(ndim))
|
29
29
|
else:
|
30
30
|
self.shape = shape
|
31
31
|
|
@@ -191,6 +191,22 @@ class Tensor:
|
|
191
191
|
|
192
192
|
return self.strides[dim]
|
193
193
|
|
194
|
+
@property
|
195
|
+
def shape(self):
|
196
|
+
return self._shape
|
197
|
+
|
198
|
+
@shape.setter
|
199
|
+
def shape(self, value):
|
200
|
+
self._shape = tuple(value)
|
201
|
+
|
202
|
+
@property
|
203
|
+
def strides(self):
|
204
|
+
return self._strides
|
205
|
+
|
206
|
+
@strides.setter
|
207
|
+
def strides(self, value):
|
208
|
+
self._strides = tuple(value)
|
209
|
+
|
194
210
|
@property
|
195
211
|
def ndim(self):
|
196
212
|
return len(self.shape)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.3
|
2
2
|
Name: ninetoothed
|
3
|
-
Version: 0.3.0
|
3
|
+
Version: 0.5.0
|
4
4
|
Summary: A domain-specific language based on Triton but providing higher-level abstraction.
|
5
5
|
Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
|
6
6
|
Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
|
@@ -15,6 +15,8 @@ Description-Content-Type: text/markdown
|
|
15
15
|
|
16
16
|
# NineToothed
|
17
17
|
|
18
|
+

|
19
|
+
|
18
20
|
A domain-specific language (DSL) based on Triton but providing higher-level abstractions.
|
19
21
|
|
20
22
|
**Other language versions: [English](README.md), [简体中文](docs/README.zh.md).**
|
@@ -80,4 +82,4 @@ def matmul_kernel(a: a_tiled, b: b_tiled, c: c_tiled):
|
|
80
82
|
|
81
83
|
For matrix multiplication, we also have three tensor parameters, but the tiling method is more complex than vector addition. We denote the three matrices as $A$, $B$, and $C$, where $A$ and $B$ are inputs, and $C$ is the output. Tiling $C$ is simple; we just need to divide it into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_N)` by rows and columns. Once each block computes its result, the entire $C$ is computed. However, how should we tile $A$ and $B$? The answer is to introduce another meta-parameter `BLOCK_SIZE_K`. This way, we can divide $A$ into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_K)` and $B$ into blocks of size `(BLOCK_SIZE_K, BLOCK_SIZE_N)`. However, for matrix multiplication, $A$ and $B$ do not correspond block by block; each row of $A$ needs to correspond to each column of $B$. Therefore, we need to further `tile` $A$ and $B$ by rows and columns, respectively. Up to this point, we have a set of row blocks of $A$ and column blocks of $B$. However, each row block of $A$ must correspond to every column block of $B$. This is where `expand` comes in. We `expand` the row blocks of $A$ along the columns to the number of columns of $C$ and the column blocks of $B$ along the rows to the number of rows of $C$. This way, we successfully tile $A$, $B$, and $C$. In fact, our meta-operations up to this point have already enabled us to write kernel functions. However, we notice that the levels where the row blocks and column blocks reside, which we mentioned earlier, are two-dimensional, and their sizes are of the forms `(1, ...)` and `(..., 1)`. This means that if no other operations are performed, the way we access row blocks and column blocks would have to be `a[0, k]` and `b[k, 0]`. If we want to use `a` to find the range of `k`, we would need to use `a.shape[1]`, but we know that dimensions of size `1` can actually be removed completely. This is why we added two lines of `squeeze`. 
The `dtype` refers to the data type, which in PyTorch can generally be some integer or floating-point type, such as `torch.float32`. However, since meta-operations like `tile` can be performed in NineToothed, `dtype` can also be a `Tensor`. In other words, there is a concept of "tensors that store tensors" in NineToothed. In summary, these two lines perform operations on the tensors stored in the outmost tensor, removing the dimensions of size `1`. This way, when we access the row and column blocks, we can use `a[k]` and `b[k]`, and when finding the range of `k`, we can use `a.shape[0]`.
|
82
84
|
|
83
|
-
With tiling done, the rest is simple. In the function body, we define an `accumulator` to accumulate intermediate results. We then iterate through the corresponding row blocks of $A$ and column blocks of B
|
85
|
+
With tiling done, the rest is simple. In the function body, we define an `accumulator` to accumulate intermediate results. We then iterate through the corresponding row blocks of $A$ and column blocks of $B$, multiplying them and accumulating the results in `accumulator`. Finally, we place the `accumulator` in the corresponding block of $C$. Since each block of the parameter tensors undergoes this operation, the multiplication is completed for the whole tensors as well.
|
@@ -0,0 +1,10 @@
|
|
1
|
+
ninetoothed/__init__.py,sha256=T5UJXlC-wbo8JKPbLUNT65Kccp12xP52WFV5FsugETI,147
|
2
|
+
ninetoothed/jit.py,sha256=q-TRGF81rUEwV1TGDrew3ijwvzCWenR8EejZbYteZSI,16188
|
3
|
+
ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
|
4
|
+
ninetoothed/symbol.py,sha256=8Wg-JQPkVv9mMIxB1Rj4SHzOytHXPgHLkuK0BEFPDkc,5243
|
5
|
+
ninetoothed/tensor.py,sha256=UO79yYwHMfdqv32Ww2mtcl-ki1C9zInC0vBNwDtzlHU,6575
|
6
|
+
ninetoothed/torchifier.py,sha256=8M2PDwyFIfVypX6Z-Vt_bGbsCPqxqKnftL0rXeh9bOM,911
|
7
|
+
ninetoothed-0.5.0.dist-info/METADATA,sha256=ObwfQtwBk3x90adbQfiSo5wK11qUG9f4NdmunyjC--0,6785
|
8
|
+
ninetoothed-0.5.0.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
|
9
|
+
ninetoothed-0.5.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
10
|
+
ninetoothed-0.5.0.dist-info/RECORD,,
|
@@ -1,10 +0,0 @@
|
|
1
|
-
ninetoothed/__init__.py,sha256=T5UJXlC-wbo8JKPbLUNT65Kccp12xP52WFV5FsugETI,147
|
2
|
-
ninetoothed/jit.py,sha256=nhjZRi8_kcjWZX0eOrnxLlzJfVg5vn12f9oi0Er2ABE,15515
|
3
|
-
ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
|
4
|
-
ninetoothed/symbol.py,sha256=Bd54qcI8KQAX0JRE_wPXycswtdSofhZ6Rr5MtZcv9fo,4665
|
5
|
-
ninetoothed/tensor.py,sha256=_DrjOJ-pBvEbSNUvUoYJduLQXmuKgNcqhe4xUDMVoZw,6275
|
6
|
-
ninetoothed/torchifier.py,sha256=8M2PDwyFIfVypX6Z-Vt_bGbsCPqxqKnftL0rXeh9bOM,911
|
7
|
-
ninetoothed-0.3.0.dist-info/METADATA,sha256=CqdtfdV0eHzSwxJmFpD2IG5d4WTc6RDlpqMZue4Ml2Q,6720
|
8
|
-
ninetoothed-0.3.0.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
|
9
|
-
ninetoothed-0.3.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
10
|
-
ninetoothed-0.3.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|