ninetoothed 0.15.0__py3-none-any.whl → 0.16.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ninetoothed/__init__.py +2 -1
- ninetoothed/aot.py +9 -3
- ninetoothed/generation.py +239 -27
- ninetoothed/jit.py +46 -4
- ninetoothed/make.py +22 -3
- ninetoothed/symbol.py +65 -2
- ninetoothed/tensor.py +18 -7
- ninetoothed/utils.py +12 -0
- ninetoothed/visualization.py +10 -4
- {ninetoothed-0.15.0.dist-info → ninetoothed-0.16.0.dist-info}/METADATA +2 -1
- ninetoothed-0.16.0.dist-info/RECORD +18 -0
- ninetoothed-0.15.0.dist-info/RECORD +0 -17
- {ninetoothed-0.15.0.dist-info → ninetoothed-0.16.0.dist-info}/WHEEL +0 -0
- {ninetoothed-0.15.0.dist-info → ninetoothed-0.16.0.dist-info}/licenses/LICENSE +0 -0
ninetoothed/__init__.py
CHANGED
@@ -13,12 +13,13 @@ from ninetoothed.dtype import (
|
|
13
13
|
)
|
14
14
|
from ninetoothed.jit import jit
|
15
15
|
from ninetoothed.make import make
|
16
|
-
from ninetoothed.symbol import Symbol
|
16
|
+
from ninetoothed.symbol import Symbol, block_size
|
17
17
|
from ninetoothed.tensor import Tensor
|
18
18
|
|
19
19
|
__all__ = [
|
20
20
|
"Symbol",
|
21
21
|
"Tensor",
|
22
|
+
"block_size",
|
22
23
|
"float16",
|
23
24
|
"float32",
|
24
25
|
"float64",
|
ninetoothed/aot.py
CHANGED
@@ -4,7 +4,7 @@ import subprocess
|
|
4
4
|
import tempfile
|
5
5
|
import uuid
|
6
6
|
|
7
|
-
from ninetoothed.dtype import int64
|
7
|
+
from ninetoothed.dtype import int64
|
8
8
|
from ninetoothed.generation import CACHE_DIR, CodeGenerator
|
9
9
|
from ninetoothed.tensor import Tensor
|
10
10
|
|
@@ -36,7 +36,13 @@ def _aot(func, caller, kernel_name, num_warps, num_stages):
|
|
36
36
|
|
37
37
|
code_generator = CodeGenerator()
|
38
38
|
source_file = code_generator(
|
39
|
-
func,
|
39
|
+
func,
|
40
|
+
caller=caller,
|
41
|
+
kernel_name=kernel_name,
|
42
|
+
num_warps=num_warps,
|
43
|
+
num_stages=num_stages,
|
44
|
+
max_num_configs=None,
|
45
|
+
prettify=False,
|
40
46
|
)
|
41
47
|
|
42
48
|
tensors = code_generator.tensors
|
@@ -55,7 +61,7 @@ def _aot(func, caller, kernel_name, num_warps, num_stages):
|
|
55
61
|
|
56
62
|
param_types.append(f"*{dtype}")
|
57
63
|
elif Tensor.size_pattern().fullmatch(param):
|
58
|
-
param_types.append(
|
64
|
+
param_types.append(int64)
|
59
65
|
elif Tensor.stride_pattern().fullmatch(param):
|
60
66
|
param_types.append(int64)
|
61
67
|
|
ninetoothed/generation.py
CHANGED
@@ -5,11 +5,20 @@ import functools
|
|
5
5
|
import hashlib
|
6
6
|
import inspect
|
7
7
|
import itertools
|
8
|
+
import json
|
8
9
|
import math
|
10
|
+
import os
|
9
11
|
import pathlib
|
12
|
+
import random
|
13
|
+
import shutil
|
10
14
|
import subprocess
|
15
|
+
import tempfile
|
16
|
+
import time
|
17
|
+
import uuid
|
11
18
|
|
19
|
+
import sympy
|
12
20
|
import triton
|
21
|
+
import triton.language as tl
|
13
22
|
|
14
23
|
import ninetoothed.naming as naming
|
15
24
|
from ninetoothed.cudaifier import Cudaifier
|
@@ -19,19 +28,47 @@ from ninetoothed.tensor import Tensor
|
|
19
28
|
from ninetoothed.torchifier import Torchifier
|
20
29
|
|
21
30
|
CACHE_DIR = pathlib.Path.home() / ".ninetoothed"
|
31
|
+
CACHE_DIR.mkdir(exist_ok=True)
|
22
32
|
|
23
33
|
|
24
34
|
class CodeGenerator(ast.NodeTransformer):
|
25
35
|
def __init__(self):
|
26
36
|
super().__init__()
|
27
37
|
|
28
|
-
|
38
|
+
cache_file = CACHE_DIR / "code_generator_cache.json"
|
29
39
|
|
30
|
-
|
40
|
+
log2_min_num_elements = 4
|
31
41
|
|
32
|
-
|
42
|
+
if cache_file.exists():
|
43
|
+
with open(cache_file) as f:
|
44
|
+
cache = json.load(f)
|
33
45
|
|
34
|
-
|
46
|
+
log2_max_num_elements = cache["log2_max_num_elements"]
|
47
|
+
else:
|
48
|
+
log2_max_num_elements = _determine_log2_max_num_elements_per_block(
|
49
|
+
log2_min_num_elements
|
50
|
+
)
|
51
|
+
|
52
|
+
cache = {"log2_max_num_elements": log2_max_num_elements}
|
53
|
+
|
54
|
+
with open(cache_file, "w") as f:
|
55
|
+
json.dump(cache, f, indent=4)
|
56
|
+
f.write("\n")
|
57
|
+
|
58
|
+
self._min_num_elements = 2**log2_min_num_elements
|
59
|
+
|
60
|
+
self._max_num_elements = 2**log2_max_num_elements
|
61
|
+
|
62
|
+
def __call__(
|
63
|
+
self,
|
64
|
+
func,
|
65
|
+
caller,
|
66
|
+
kernel_name,
|
67
|
+
num_warps,
|
68
|
+
num_stages,
|
69
|
+
max_num_configs,
|
70
|
+
prettify,
|
71
|
+
):
|
35
72
|
def _get_tree(func):
|
36
73
|
module = ast.parse(inspect.getsource(inspect.getmodule(func)))
|
37
74
|
|
@@ -63,6 +100,12 @@ class CodeGenerator(ast.NodeTransformer):
|
|
63
100
|
|
64
101
|
self._caller = caller
|
65
102
|
|
103
|
+
self._num_wraps = num_warps
|
104
|
+
|
105
|
+
self._num_stages = num_stages
|
106
|
+
|
107
|
+
self._max_num_configs = max_num_configs
|
108
|
+
|
66
109
|
self._context = inspect.get_annotations(func)
|
67
110
|
|
68
111
|
self._args = list(self._context.values())
|
@@ -82,6 +125,7 @@ class CodeGenerator(ast.NodeTransformer):
|
|
82
125
|
dependencies = _find_dependencies(func)
|
83
126
|
source = "\n\n".join((unparsed, dependencies)).strip()
|
84
127
|
source = source.replace(func.__name__, kernel_name)
|
128
|
+
source += "\n"
|
85
129
|
|
86
130
|
if prettify:
|
87
131
|
for original, simplified in name_collector.simplified_names.items():
|
@@ -93,9 +137,7 @@ class CodeGenerator(ast.NodeTransformer):
|
|
93
137
|
)
|
94
138
|
|
95
139
|
digest = hashlib.sha256(source.encode("utf-8")).hexdigest()
|
96
|
-
|
97
|
-
cache_dir.mkdir(exist_ok=True)
|
98
|
-
cache_file = cache_dir / f"{digest}.py"
|
140
|
+
cache_file = CACHE_DIR / f"{digest}.py"
|
99
141
|
|
100
142
|
if not cache_file.exists():
|
101
143
|
with open(cache_file, "w", encoding="utf-8") as f:
|
@@ -110,11 +152,15 @@ class CodeGenerator(ast.NodeTransformer):
|
|
110
152
|
def visit_Module(self, node):
|
111
153
|
self.generic_visit(node)
|
112
154
|
|
113
|
-
|
155
|
+
if self._autotune is not None:
|
156
|
+
func_with_auto_tuning = f"{Symbol(self._autotune)}({self._func_def.name})"
|
157
|
+
|
158
|
+
node.body.append(
|
159
|
+
ast.parse(
|
160
|
+
f"{self._func_name_with_auto_tuning} = {func_with_auto_tuning}"
|
161
|
+
)
|
162
|
+
)
|
114
163
|
|
115
|
-
node.body.append(
|
116
|
-
ast.parse(f"{self._func_name_with_auto_tuning} = {func_with_auto_tuning}")
|
117
|
-
)
|
118
164
|
node.body.append(self._launch)
|
119
165
|
|
120
166
|
return node
|
@@ -136,8 +182,13 @@ class CodeGenerator(ast.NodeTransformer):
|
|
136
182
|
def visit_arguments(self, node):
|
137
183
|
self.generic_visit(node)
|
138
184
|
|
139
|
-
|
140
|
-
|
185
|
+
symbols = {
|
186
|
+
name.node.id: name
|
187
|
+
for arg in self._args
|
188
|
+
for name in arg.names()
|
189
|
+
if name != "ninetoothed"
|
190
|
+
}
|
191
|
+
names = symbols.keys()
|
141
192
|
meta_names = {name for name in names if naming.is_meta(name)}
|
142
193
|
non_meta_names = {name for name in names if name not in meta_names}
|
143
194
|
non_meta_names |= {
|
@@ -146,6 +197,8 @@ class CodeGenerator(ast.NodeTransformer):
|
|
146
197
|
if naming.is_constexpr(name)
|
147
198
|
}
|
148
199
|
|
200
|
+
self._symbols = symbols
|
201
|
+
|
149
202
|
non_meta_names = sorted(non_meta_names)
|
150
203
|
meta_names = sorted(meta_names)
|
151
204
|
|
@@ -160,6 +213,12 @@ class CodeGenerator(ast.NodeTransformer):
|
|
160
213
|
]
|
161
214
|
|
162
215
|
self._autotune = self._generate_autotune(non_meta_names, meta_names)
|
216
|
+
|
217
|
+
if self._autotune is not None:
|
218
|
+
self._func_name = self._func_name_with_auto_tuning
|
219
|
+
else:
|
220
|
+
self._func_name = self._func_def.name
|
221
|
+
|
163
222
|
self._func_def.decorator_list = [Symbol("triton.jit").node]
|
164
223
|
|
165
224
|
self._launch = self._generate_launch(non_meta_names, meta_names)
|
@@ -243,12 +302,69 @@ class CodeGenerator(ast.NodeTransformer):
|
|
243
302
|
return isinstance(node, ast.Name) and node.id in self._context
|
244
303
|
|
245
304
|
def _generate_autotune(self, params, meta):
|
246
|
-
|
247
|
-
|
248
|
-
|
305
|
+
inequalities = True
|
306
|
+
|
307
|
+
for arg in self._args:
|
308
|
+
if arg.ndim == 0:
|
309
|
+
continue
|
310
|
+
|
311
|
+
num_elements = sympy.simplify(str(math.prod(arg.innermost().shape)))
|
312
|
+
|
313
|
+
inequalities &= num_elements <= self._max_num_elements
|
314
|
+
inequalities &= num_elements >= self._min_num_elements
|
315
|
+
|
316
|
+
values_of_meta_params = []
|
317
|
+
|
318
|
+
for param in meta:
|
319
|
+
symbol = self._symbols[param]
|
320
|
+
|
321
|
+
values = range(symbol.lower_bound, symbol.upper_bound + 1)
|
322
|
+
|
323
|
+
if symbol.power_of_two:
|
324
|
+
values = tuple(value for value in values if value & (value - 1) == 0)
|
325
|
+
else:
|
326
|
+
values = tuple(values)
|
327
|
+
|
328
|
+
values_of_meta_params.append(values)
|
329
|
+
|
330
|
+
max_values_of_non_meta_params = {}
|
331
|
+
|
332
|
+
for free_symbol in inequalities.free_symbols:
|
333
|
+
symbol_str = str(free_symbol)
|
334
|
+
|
335
|
+
if symbol_str in meta:
|
336
|
+
continue
|
337
|
+
|
338
|
+
symbol = self._symbols[symbol_str]
|
339
|
+
|
340
|
+
max_values_of_non_meta_params[symbol_str] = symbol.upper_bound
|
341
|
+
|
342
|
+
block_size_configs = []
|
343
|
+
|
344
|
+
for values in itertools.product(*values_of_meta_params):
|
345
|
+
config = {param: value for param, value in zip(meta, values)}
|
346
|
+
|
347
|
+
if sympy.logic.simplify_logic(
|
348
|
+
inequalities.subs(config | max_values_of_non_meta_params)
|
349
|
+
):
|
350
|
+
block_size_configs.append(config)
|
351
|
+
|
352
|
+
if isinstance(self._num_wraps, collections.abc.Iterable):
|
353
|
+
num_warps_configs = self._num_wraps
|
354
|
+
else:
|
355
|
+
num_warps_configs = (self._num_wraps,)
|
356
|
+
|
357
|
+
if isinstance(self._num_stages, collections.abc.Iterable):
|
358
|
+
num_stages_configs = self._num_stages
|
359
|
+
else:
|
360
|
+
num_stages_configs = (self._num_stages,)
|
249
361
|
|
250
|
-
|
251
|
-
|
362
|
+
compiler_configs = tuple(
|
363
|
+
{"num_warps": num_warps, "num_stages": num_stages}
|
364
|
+
for num_warps, num_stages in itertools.product(
|
365
|
+
num_warps_configs, num_stages_configs
|
366
|
+
)
|
367
|
+
)
|
252
368
|
|
253
369
|
configs = [
|
254
370
|
ast.Call(
|
@@ -259,19 +375,38 @@ class CodeGenerator(ast.NodeTransformer):
|
|
259
375
|
),
|
260
376
|
args=[
|
261
377
|
ast.Dict(
|
262
|
-
keys=[
|
263
|
-
|
378
|
+
keys=[
|
379
|
+
ast.Constant(value=param)
|
380
|
+
for param in block_size_config.keys()
|
381
|
+
],
|
382
|
+
values=[
|
383
|
+
ast.Constant(value=value)
|
384
|
+
for value in block_size_config.values()
|
385
|
+
],
|
264
386
|
)
|
265
387
|
],
|
266
388
|
keywords=[
|
267
|
-
ast.keyword(
|
268
|
-
|
389
|
+
ast.keyword(
|
390
|
+
arg="num_warps",
|
391
|
+
value=ast.Constant(value=compiler_config["num_warps"]),
|
392
|
+
),
|
393
|
+
ast.keyword(
|
394
|
+
arg="num_stages",
|
395
|
+
value=ast.Constant(value=compiler_config["num_stages"]),
|
396
|
+
),
|
269
397
|
],
|
270
398
|
)
|
271
|
-
for
|
272
|
-
|
399
|
+
for block_size_config, compiler_config in itertools.product(
|
400
|
+
block_size_configs, compiler_configs
|
401
|
+
)
|
273
402
|
]
|
274
403
|
|
404
|
+
if self._max_num_configs is not None and len(configs) > self._max_num_configs:
|
405
|
+
configs = random.sample(configs, k=self._max_num_configs)
|
406
|
+
|
407
|
+
if not configs:
|
408
|
+
return None
|
409
|
+
|
275
410
|
return ast.Call(
|
276
411
|
func=ast.Attribute(
|
277
412
|
value=ast.Name(id="ninetoothed", ctx=ast.Load()),
|
@@ -357,9 +492,7 @@ class CodeGenerator(ast.NodeTransformer):
|
|
357
492
|
ast.Expr(
|
358
493
|
ast.Call(
|
359
494
|
func=ast.Subscript(
|
360
|
-
value=ast.Name(
|
361
|
-
id=self._func_name_with_auto_tuning, ctx=ast.Load()
|
362
|
-
),
|
495
|
+
value=ast.Name(id=self._func_name, ctx=ast.Load()),
|
363
496
|
slice=self._generate_grid(),
|
364
497
|
ctx=ast.Load(),
|
365
498
|
),
|
@@ -993,3 +1126,82 @@ class _FunctionDefFinder(ast.NodeVisitor):
|
|
993
1126
|
self.result = node
|
994
1127
|
|
995
1128
|
self.generic_visit(node)
|
1129
|
+
|
1130
|
+
|
1131
|
+
def _determine_log2_max_num_elements_per_block(
|
1132
|
+
min_exponent, max_exponent=30, num_iterations=3
|
1133
|
+
):
|
1134
|
+
_profile_pseudo_add_kernel(1)
|
1135
|
+
|
1136
|
+
for n in range(min_exponent, max_exponent + 1):
|
1137
|
+
elapsed_time = 0
|
1138
|
+
|
1139
|
+
for _ in range(num_iterations):
|
1140
|
+
elapsed_time += _profile_pseudo_add_kernel(2**n)
|
1141
|
+
|
1142
|
+
average_elapsed_time = elapsed_time / num_iterations
|
1143
|
+
|
1144
|
+
if average_elapsed_time >= 1:
|
1145
|
+
return n - 1
|
1146
|
+
|
1147
|
+
|
1148
|
+
def _profile_pseudo_add_kernel(block_size):
|
1149
|
+
cache_dir = triton.runtime.cache.default_cache_dir()
|
1150
|
+
os.makedirs(cache_dir, exist_ok=True)
|
1151
|
+
|
1152
|
+
with tempfile.TemporaryDirectory() as backup_dir:
|
1153
|
+
backup_path = os.path.join(backup_dir, str(uuid.uuid4()))
|
1154
|
+
|
1155
|
+
if os.path.exists(backup_path):
|
1156
|
+
shutil.rmtree(backup_path)
|
1157
|
+
|
1158
|
+
shutil.move(cache_dir, backup_path)
|
1159
|
+
|
1160
|
+
try:
|
1161
|
+
start_time = time.time()
|
1162
|
+
|
1163
|
+
_run_pseudo_add_kernel(block_size)
|
1164
|
+
|
1165
|
+
end_time = time.time()
|
1166
|
+
|
1167
|
+
elapsed_time = end_time - start_time
|
1168
|
+
finally:
|
1169
|
+
if os.path.exists(cache_dir):
|
1170
|
+
shutil.rmtree(cache_dir)
|
1171
|
+
|
1172
|
+
shutil.move(backup_path, cache_dir)
|
1173
|
+
|
1174
|
+
return elapsed_time
|
1175
|
+
|
1176
|
+
|
1177
|
+
def _run_pseudo_add_kernel(block_size):
|
1178
|
+
@triton.jit
|
1179
|
+
def kernel(a_ptr, b_ptr, c_ptr, num_elements, BLOCK_SIZE: tl.constexpr):
|
1180
|
+
pid = tl.program_id(0)
|
1181
|
+
|
1182
|
+
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
1183
|
+
mask = offs < num_elements
|
1184
|
+
|
1185
|
+
a = tl.load(a_ptr + offs, mask=mask)
|
1186
|
+
b = tl.load(b_ptr + offs, mask=mask)
|
1187
|
+
|
1188
|
+
c = a + b
|
1189
|
+
|
1190
|
+
tl.store(c_ptr + offs, c, mask=mask)
|
1191
|
+
|
1192
|
+
num_elements = 0
|
1193
|
+
shape = (num_elements,)
|
1194
|
+
dtype = tl.float32
|
1195
|
+
|
1196
|
+
a = Tensor(shape=shape, dtype=dtype)
|
1197
|
+
b = Tensor(shape=shape, dtype=dtype)
|
1198
|
+
c = Tensor(shape=shape, dtype=dtype)
|
1199
|
+
|
1200
|
+
def data_ptr():
|
1201
|
+
return 0
|
1202
|
+
|
1203
|
+
a.data_ptr = data_ptr
|
1204
|
+
b.data_ptr = data_ptr
|
1205
|
+
c.data_ptr = data_ptr
|
1206
|
+
|
1207
|
+
kernel[(1,)](a, b, c, num_elements, block_size)
|
ninetoothed/jit.py
CHANGED
@@ -4,12 +4,25 @@ import sys
|
|
4
4
|
from ninetoothed.generation import CodeGenerator
|
5
5
|
|
6
6
|
|
7
|
-
def jit(func=None, *, caller="torch", kernel_name=None, _prettify=False):
|
7
|
+
def jit(
|
8
|
+
func=None,
|
9
|
+
*,
|
10
|
+
caller="torch",
|
11
|
+
kernel_name=None,
|
12
|
+
num_warps=None,
|
13
|
+
num_stages=None,
|
14
|
+
max_num_configs=None,
|
15
|
+
_prettify=False,
|
16
|
+
):
|
8
17
|
"""A decorator for generating compute kernels.
|
9
18
|
|
10
19
|
:param func: The function to be compiled.
|
11
20
|
:param caller: Who will call the compute kernel.
|
12
21
|
:param kernel_name: The name for the generated kernel.
|
22
|
+
:param num_warps: The number of warps to use.
|
23
|
+
:param num_stages: The number of pipeline stages.
|
24
|
+
:param max_num_configs: The maximum number of auto-tuning
|
25
|
+
configurations to use.
|
13
26
|
:param _prettify: Whether to prettify the generated code.
|
14
27
|
:return: A handle to the compute kernel.
|
15
28
|
|
@@ -20,7 +33,15 @@ def jit(func=None, *, caller="torch", kernel_name=None, _prettify=False):
|
|
20
33
|
"""
|
21
34
|
|
22
35
|
def wrapper(func):
|
23
|
-
return JIT(
|
36
|
+
return JIT(
|
37
|
+
func,
|
38
|
+
caller=caller,
|
39
|
+
kernel_name=kernel_name,
|
40
|
+
num_warps=num_warps,
|
41
|
+
num_stages=num_stages,
|
42
|
+
max_num_configs=max_num_configs,
|
43
|
+
_prettify=_prettify,
|
44
|
+
)()
|
24
45
|
|
25
46
|
if func is None:
|
26
47
|
return wrapper
|
@@ -29,7 +50,16 @@ def jit(func=None, *, caller="torch", kernel_name=None, _prettify=False):
|
|
29
50
|
|
30
51
|
|
31
52
|
class JIT:
|
32
|
-
def __init__(
|
53
|
+
def __init__(
|
54
|
+
self,
|
55
|
+
func,
|
56
|
+
caller,
|
57
|
+
kernel_name,
|
58
|
+
num_warps,
|
59
|
+
num_stages,
|
60
|
+
max_num_configs,
|
61
|
+
_prettify=False,
|
62
|
+
):
|
33
63
|
self.func = func
|
34
64
|
|
35
65
|
self._caller = caller
|
@@ -39,12 +69,24 @@ class JIT:
|
|
39
69
|
else:
|
40
70
|
self._kernel_name = func.__name__
|
41
71
|
|
72
|
+
self._num_warps = num_warps
|
73
|
+
|
74
|
+
self._num_stages = num_stages
|
75
|
+
|
76
|
+
self._max_num_configs = max_num_configs
|
77
|
+
|
42
78
|
self._prettify = _prettify
|
43
79
|
|
44
80
|
def __call__(self):
|
45
81
|
code_generator = CodeGenerator()
|
46
82
|
source_file = code_generator(
|
47
|
-
self.func,
|
83
|
+
self.func,
|
84
|
+
self._caller,
|
85
|
+
self._kernel_name,
|
86
|
+
self._num_warps,
|
87
|
+
self._num_stages,
|
88
|
+
self._max_num_configs,
|
89
|
+
self._prettify,
|
48
90
|
)
|
49
91
|
module = type(self)._import_from_path(source_file, source_file)
|
50
92
|
module_vars = vars(module)
|
ninetoothed/make.py
CHANGED
@@ -2,6 +2,7 @@ import inspect
|
|
2
2
|
|
3
3
|
from ninetoothed.aot import aot
|
4
4
|
from ninetoothed.jit import jit
|
5
|
+
from ninetoothed.utils import calculate_default_configs
|
5
6
|
|
6
7
|
|
7
8
|
def make(
|
@@ -11,8 +12,9 @@ def make(
|
|
11
12
|
caller="torch",
|
12
13
|
kernel_name=None,
|
13
14
|
output_dir=None,
|
14
|
-
num_warps=
|
15
|
-
num_stages=
|
15
|
+
num_warps=None,
|
16
|
+
num_stages=None,
|
17
|
+
max_num_configs=None,
|
16
18
|
):
|
17
19
|
"""Integrate the arrangement and the application of the tensors.
|
18
20
|
|
@@ -24,16 +26,33 @@ def make(
|
|
24
26
|
:param output_dir: The directory to store the generated files.
|
25
27
|
:param num_warps: The number of warps to use.
|
26
28
|
:param num_stages: The number of pipeline stages.
|
29
|
+
:param max_num_configs: The maximum number of auto-tuning
|
30
|
+
configurations to use.
|
27
31
|
:return: A handle to the compute kernel.
|
28
32
|
"""
|
29
33
|
|
34
|
+
default_num_warps, default_num_stages = calculate_default_configs()
|
35
|
+
|
36
|
+
if num_warps is None:
|
37
|
+
num_warps = default_num_warps
|
38
|
+
|
39
|
+
if num_stages is None:
|
40
|
+
num_stages = default_num_stages
|
41
|
+
|
30
42
|
params = inspect.signature(application).parameters
|
31
43
|
types = arrangement(*tensors)
|
32
44
|
annotations = {param: type for param, type in zip(params, types)}
|
33
45
|
application.__annotations__ = annotations
|
34
46
|
|
35
47
|
if caller == "torch":
|
36
|
-
return jit(
|
48
|
+
return jit(
|
49
|
+
application,
|
50
|
+
caller=caller,
|
51
|
+
kernel_name=kernel_name,
|
52
|
+
num_warps=num_warps,
|
53
|
+
num_stages=num_stages,
|
54
|
+
max_num_configs=max_num_configs,
|
55
|
+
)
|
37
56
|
|
38
57
|
return aot(
|
39
58
|
application,
|
ninetoothed/symbol.py
CHANGED
@@ -12,9 +12,20 @@ class Symbol:
|
|
12
12
|
:param expr: The expression used to construct the symbol.
|
13
13
|
:param constexpr: Whether the symbol is a constexpr.
|
14
14
|
:param mata: Whether the symbol is a meta.
|
15
|
+
:param lower_bound: The minimum value for the symbol's range.
|
16
|
+
:param upper_bound: The maximum value for the symbol's range.
|
17
|
+
:param power_of_two: Whether the value should be a power of two.
|
15
18
|
"""
|
16
19
|
|
17
|
-
def __init__(
|
20
|
+
def __init__(
|
21
|
+
self,
|
22
|
+
expr,
|
23
|
+
constexpr=None,
|
24
|
+
meta=None,
|
25
|
+
lower_bound=None,
|
26
|
+
upper_bound=None,
|
27
|
+
power_of_two=None,
|
28
|
+
):
|
18
29
|
if isinstance(expr, type(self)):
|
19
30
|
self._node = expr._node
|
20
31
|
return
|
@@ -43,6 +54,40 @@ class Symbol:
|
|
43
54
|
if constexpr:
|
44
55
|
self._node.id = naming.make_constexpr(self._node.id)
|
45
56
|
|
57
|
+
self._node.symbol = self
|
58
|
+
|
59
|
+
DEFAULT_LOWER_BOUND_FOR_META_SYMBOLS = 2**5
|
60
|
+
DEFAULT_UPPER_BOUND_FOR_META_SYMBOLS = 2**10
|
61
|
+
DEFAULT_POWER_OF_TWO_FOR_META_SYMBOLS = True
|
62
|
+
|
63
|
+
DEFAULT_LOWER_BOUND_FOR_NON_META_CONSTEXPR_SYMBOLS = 1
|
64
|
+
DEFAULT_UPPER_BOUND_FOR_NON_META_CONSTEXPR_SYMBOLS = 2**20
|
65
|
+
DEFAULT_POWER_OF_TWO_FOR_NON_META_CONSTEXPR_SYMBOLS = False
|
66
|
+
|
67
|
+
if lower_bound is not None:
|
68
|
+
self.lower_bound = lower_bound
|
69
|
+
else:
|
70
|
+
if meta:
|
71
|
+
self.lower_bound = DEFAULT_LOWER_BOUND_FOR_META_SYMBOLS
|
72
|
+
elif constexpr:
|
73
|
+
self.lower_bound = DEFAULT_LOWER_BOUND_FOR_NON_META_CONSTEXPR_SYMBOLS
|
74
|
+
|
75
|
+
if upper_bound is not None:
|
76
|
+
self.upper_bound = upper_bound
|
77
|
+
else:
|
78
|
+
if meta:
|
79
|
+
self.upper_bound = DEFAULT_UPPER_BOUND_FOR_META_SYMBOLS
|
80
|
+
elif constexpr:
|
81
|
+
self.upper_bound = DEFAULT_UPPER_BOUND_FOR_NON_META_CONSTEXPR_SYMBOLS
|
82
|
+
|
83
|
+
if power_of_two is not None:
|
84
|
+
self.power_of_two = power_of_two
|
85
|
+
else:
|
86
|
+
if meta:
|
87
|
+
self.power_of_two = DEFAULT_POWER_OF_TWO_FOR_META_SYMBOLS
|
88
|
+
elif constexpr:
|
89
|
+
self.power_of_two = DEFAULT_POWER_OF_TWO_FOR_NON_META_CONSTEXPR_SYMBOLS
|
90
|
+
|
46
91
|
def __eq__(self, other):
|
47
92
|
if isinstance(self._node, ast.Constant):
|
48
93
|
if isinstance(other, Symbol) and isinstance(other._node, ast.Constant):
|
@@ -155,7 +200,7 @@ class Symbol:
|
|
155
200
|
def visit_Name(self, node):
|
156
201
|
self.generic_visit(node)
|
157
202
|
|
158
|
-
self.names.add(node.
|
203
|
+
self.names.add(node.symbol)
|
159
204
|
|
160
205
|
name_collector = NameCollector()
|
161
206
|
|
@@ -179,6 +224,24 @@ class Symbol:
|
|
179
224
|
return isinstance(object, Symbol) and isinstance(object.node, ast.Name)
|
180
225
|
|
181
226
|
|
227
|
+
def block_size(lower_bound=None, upper_bound=None):
|
228
|
+
"""Create a block size symbol that serves as a meta-parameter.
|
229
|
+
|
230
|
+
:param lower_bound: The lower bound for the block size's range.
|
231
|
+
:param upper_bound: The upper bound for the block size's range.
|
232
|
+
:return: A block size symbol that serves as a meta-parameter.
|
233
|
+
"""
|
234
|
+
|
235
|
+
name = naming.auto_generate(f"BLOCK_SIZE_{block_size._num_block_sizes}")
|
236
|
+
|
237
|
+
block_size._num_block_sizes += 1
|
238
|
+
|
239
|
+
return Symbol(name, meta=True, lower_bound=lower_bound, upper_bound=upper_bound)
|
240
|
+
|
241
|
+
|
242
|
+
block_size._num_block_sizes = 0
|
243
|
+
|
244
|
+
|
182
245
|
class _FindAndReplacer(ast.NodeTransformer):
|
183
246
|
def __init__(self, targets, replacement):
|
184
247
|
self._targets_unparsed = tuple(
|
ninetoothed/tensor.py
CHANGED
@@ -14,7 +14,7 @@ class Tensor:
|
|
14
14
|
:param dtype: The element type of the tensor.
|
15
15
|
:param strides: The strides of the tensor.
|
16
16
|
:param other: The values for out-of-bounds positions.
|
17
|
-
:param
|
17
|
+
:param shape_options: The options for configuring shape symbols.
|
18
18
|
:param name: The name of the tensor.
|
19
19
|
:param source: For internal use only.
|
20
20
|
:param source_dims: For internal use only.
|
@@ -31,7 +31,7 @@ class Tensor:
|
|
31
31
|
dtype=None,
|
32
32
|
strides=None,
|
33
33
|
other=None,
|
34
|
-
|
34
|
+
shape_options=None,
|
35
35
|
name=None,
|
36
36
|
source=None,
|
37
37
|
source_dims=None,
|
@@ -48,9 +48,20 @@ class Tensor:
|
|
48
48
|
self.name = naming.auto_generate(f"tensor_{type(self).num_instances}")
|
49
49
|
|
50
50
|
if ndim is not None:
|
51
|
+
if shape_options is None:
|
52
|
+
shape_options = tuple({} for _ in range(ndim))
|
53
|
+
|
54
|
+
if isinstance(shape_options, dict):
|
55
|
+
shape_options = tuple(shape_options for _ in range(ndim))
|
56
|
+
|
57
|
+
shape_options = tuple(
|
58
|
+
size_options if size_options is not None else {}
|
59
|
+
for size_options in shape_options
|
60
|
+
)
|
61
|
+
|
51
62
|
self.shape = (
|
52
|
-
Symbol(self.size_string(i),
|
53
|
-
for i in range(ndim)
|
63
|
+
Symbol(self.size_string(i), **size_options)
|
64
|
+
for i, size_options in zip(range(ndim), shape_options)
|
54
65
|
)
|
55
66
|
self.strides = (Symbol(self.stride_string(i)) for i in range(ndim))
|
56
67
|
else:
|
@@ -146,7 +157,7 @@ class Tensor:
|
|
146
157
|
)
|
147
158
|
outer_shape.append(new_size)
|
148
159
|
|
149
|
-
new_stride = self_stride * stride
|
160
|
+
new_stride = self_stride * stride
|
150
161
|
outer_strides.append(new_stride)
|
151
162
|
|
152
163
|
inner_shape.append(tile_size)
|
@@ -364,10 +375,10 @@ class Tensor:
|
|
364
375
|
|
365
376
|
def names(self):
|
366
377
|
if self.ndim == 0:
|
367
|
-
return {self.source.name}
|
378
|
+
return {Symbol(self.source.name)}
|
368
379
|
|
369
380
|
return (
|
370
|
-
{self.source.pointer_string()}
|
381
|
+
{Symbol(self.source.pointer_string())}
|
371
382
|
| {
|
372
383
|
name
|
373
384
|
for value in itertools.chain(self.shape, self.strides)
|
ninetoothed/utils.py
ADDED
@@ -0,0 +1,12 @@
|
|
1
|
+
import triton
|
2
|
+
|
3
|
+
|
4
|
+
def calculate_default_configs():
|
5
|
+
device = triton.runtime.driver.active.get_current_device()
|
6
|
+
properties = triton.runtime.driver.active.utils.get_device_properties(device)
|
7
|
+
max_shared_mem = properties["max_shared_mem"]
|
8
|
+
|
9
|
+
num_warps = 8
|
10
|
+
num_stages = max_shared_mem // 2**15
|
11
|
+
|
12
|
+
return num_warps, num_stages
|
ninetoothed/visualization.py
CHANGED
@@ -118,10 +118,16 @@ def _visualize_unit_square(ax, x, y, color):
|
|
118
118
|
|
119
119
|
|
120
120
|
def _visualize_rect(ax, width, height, x, y, color):
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
121
|
+
ax.add_patch(
|
122
|
+
plt.Rectangle(
|
123
|
+
(x, y),
|
124
|
+
width,
|
125
|
+
height,
|
126
|
+
edgecolor="k",
|
127
|
+
facecolor=color,
|
128
|
+
linewidth=plt.rcParams["lines.linewidth"],
|
129
|
+
)
|
130
|
+
)
|
125
131
|
|
126
132
|
|
127
133
|
def _verts_of_rect(width, height, x, y):
|
{ninetoothed-0.15.0.dist-info → ninetoothed-0.16.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ninetoothed
|
3
|
-
Version: 0.15.0
|
3
|
+
Version: 0.16.0
|
4
4
|
Summary: A domain-specific language based on Triton but providing higher-level abstraction.
|
5
5
|
Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
|
6
6
|
Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
|
@@ -10,6 +10,7 @@ Classifier: License :: OSI Approved :: Apache Software License
|
|
10
10
|
Classifier: Operating System :: OS Independent
|
11
11
|
Classifier: Programming Language :: Python :: 3
|
12
12
|
Requires-Python: >=3.10
|
13
|
+
Requires-Dist: sympy>=1.13.0
|
13
14
|
Requires-Dist: triton>=3.0.0
|
14
15
|
Provides-Extra: all
|
15
16
|
Requires-Dist: matplotlib>=3.9.0; extra == 'all'
|
ninetoothed-0.16.0.dist-info/RECORD
ADDED
@@ -0,0 +1,18 @@
|
|
1
|
+
ninetoothed/__init__.py,sha256=F2bxRNhzcGdtADA8RehTuf-QK0xnxno8kxvr6H2L5Tg,552
|
2
|
+
ninetoothed/aot.py,sha256=8ZCLtnsign14YvY7SXX5ASidhuUAhPwppTXUJNkQup4,6243
|
3
|
+
ninetoothed/cudaifier.py,sha256=5ylMr1q0B9NwbeXkpCu3o2nMGpDfh65nAQ0Az_qMQuI,877
|
4
|
+
ninetoothed/dtype.py,sha256=-0iBleay5gYA4wtT3l17QjCesr7g26M6CSfhNJdI3k4,165
|
5
|
+
ninetoothed/generation.py,sha256=VIqSyZT4yHxY_a2QPmWW6jjALv3e1mohDqdRQBRYsAo,36462
|
6
|
+
ninetoothed/jit.py,sha256=CpeSkO_zUe9DwtTJ2K2H7Bwpx-FvIHfrgzOcEosfpek,2946
|
7
|
+
ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
|
8
|
+
ninetoothed/make.py,sha256=fQKuRJL7HC2iGTAN323mlIWXz9Z3jotIoN68ur29Qlw,1834
|
9
|
+
ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
|
10
|
+
ninetoothed/symbol.py,sha256=lJo3NL2-T7tKbKjb6MCRLMemN94mqS3bIiG943P0Mbo,7454
|
11
|
+
ninetoothed/tensor.py,sha256=gQEzHTcXqZVBFLc2YRfXTKxjxPWMxWN7fNl2BCfJwMs,14782
|
12
|
+
ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
|
13
|
+
ninetoothed/utils.py,sha256=mtRXABBVPnlgd2n1REh9oB3s_5bUsKhd3iwu3oJ5DSQ,338
|
14
|
+
ninetoothed/visualization.py,sha256=zlMH-0WplaboePGzcbpcj4UovpX0k2r4SysSPsNS4r4,3674
|
15
|
+
ninetoothed-0.16.0.dist-info/METADATA,sha256=nkq3iImebtmcEs-bZq2zfF2_QxrZD9IWky1S86OnUMA,7340
|
16
|
+
ninetoothed-0.16.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
17
|
+
ninetoothed-0.16.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
18
|
+
ninetoothed-0.16.0.dist-info/RECORD,,
|
ninetoothed-0.15.0.dist-info/RECORD
REMOVED
@@ -1,17 +0,0 @@
|
|
1
|
-
ninetoothed/__init__.py,sha256=zGaZiUzwJZ2jfwLxp7lT8ll_V5ngP5QYrfVbapftbCY,522
|
2
|
-
ninetoothed/aot.py,sha256=1hzl4-6MqscB4tDMqmLCOTlyzsYkbY20EnmDgHO8hU4,6137
|
3
|
-
ninetoothed/cudaifier.py,sha256=5ylMr1q0B9NwbeXkpCu3o2nMGpDfh65nAQ0Az_qMQuI,877
|
4
|
-
ninetoothed/dtype.py,sha256=-0iBleay5gYA4wtT3l17QjCesr7g26M6CSfhNJdI3k4,165
|
5
|
-
ninetoothed/generation.py,sha256=QHrK7DOuJo5wEV-5HAqu_e-suuD4TPPNmCzrUMRYF2w,30940
|
6
|
-
ninetoothed/jit.py,sha256=0MFbFIODtw-bxuOC7WByxiVtQMeyvZkoDxvfAZ9rIFQ,2120
|
7
|
-
ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
|
8
|
-
ninetoothed/make.py,sha256=wRr3JwGt5E2OCquq_nzBZljdW-AJPOqH49cM08gwl4A,1287
|
9
|
-
ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
|
10
|
-
ninetoothed/symbol.py,sha256=UpGmx_jvaDtowADnp1DwYC3fvBXSiaMiYpU-ewkVo50,5261
|
11
|
-
ninetoothed/tensor.py,sha256=W1XY8_vaYmszX4lIWuas-ZKGbbdEZU7Z5h1A4FBXDXg,14358
|
12
|
-
ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
|
13
|
-
ninetoothed/visualization.py,sha256=IZ7iTT4dl5_JFbO-WfSWPFWpgkyPr4nylwhSZVy8gss,3601
|
14
|
-
ninetoothed-0.15.0.dist-info/METADATA,sha256=UmND-TBDf7vrdii9dhiOZmiZBzrcG8-xEniThTSikqM,7311
|
15
|
-
ninetoothed-0.15.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
16
|
-
ninetoothed-0.15.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
17
|
-
ninetoothed-0.15.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|