ninetoothed 0.15.1__py3-none-any.whl → 0.16.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ninetoothed/__init__.py +2 -1
- ninetoothed/aot.py +7 -1
- ninetoothed/generation.py +238 -27
- ninetoothed/jit.py +46 -4
- ninetoothed/make.py +22 -3
- ninetoothed/symbol.py +65 -2
- ninetoothed/tensor.py +17 -6
- ninetoothed/utils.py +12 -0
- {ninetoothed-0.15.1.dist-info → ninetoothed-0.16.0.dist-info}/METADATA +2 -1
- ninetoothed-0.16.0.dist-info/RECORD +18 -0
- ninetoothed-0.15.1.dist-info/RECORD +0 -17
- {ninetoothed-0.15.1.dist-info → ninetoothed-0.16.0.dist-info}/WHEEL +0 -0
- {ninetoothed-0.15.1.dist-info → ninetoothed-0.16.0.dist-info}/licenses/LICENSE +0 -0
ninetoothed/__init__.py
CHANGED
@@ -13,12 +13,13 @@ from ninetoothed.dtype import (
|
|
13
13
|
)
|
14
14
|
from ninetoothed.jit import jit
|
15
15
|
from ninetoothed.make import make
|
16
|
-
from ninetoothed.symbol import Symbol
|
16
|
+
from ninetoothed.symbol import Symbol, block_size
|
17
17
|
from ninetoothed.tensor import Tensor
|
18
18
|
|
19
19
|
__all__ = [
|
20
20
|
"Symbol",
|
21
21
|
"Tensor",
|
22
|
+
"block_size",
|
22
23
|
"float16",
|
23
24
|
"float32",
|
24
25
|
"float64",
|
ninetoothed/aot.py
CHANGED
@@ -36,7 +36,13 @@ def _aot(func, caller, kernel_name, num_warps, num_stages):
|
|
36
36
|
|
37
37
|
code_generator = CodeGenerator()
|
38
38
|
source_file = code_generator(
|
39
|
-
func,
|
39
|
+
func,
|
40
|
+
caller=caller,
|
41
|
+
kernel_name=kernel_name,
|
42
|
+
num_warps=num_warps,
|
43
|
+
num_stages=num_stages,
|
44
|
+
max_num_configs=None,
|
45
|
+
prettify=False,
|
40
46
|
)
|
41
47
|
|
42
48
|
tensors = code_generator.tensors
|
ninetoothed/generation.py
CHANGED
@@ -5,11 +5,20 @@ import functools
|
|
5
5
|
import hashlib
|
6
6
|
import inspect
|
7
7
|
import itertools
|
8
|
+
import json
|
8
9
|
import math
|
10
|
+
import os
|
9
11
|
import pathlib
|
12
|
+
import random
|
13
|
+
import shutil
|
10
14
|
import subprocess
|
15
|
+
import tempfile
|
16
|
+
import time
|
17
|
+
import uuid
|
11
18
|
|
19
|
+
import sympy
|
12
20
|
import triton
|
21
|
+
import triton.language as tl
|
13
22
|
|
14
23
|
import ninetoothed.naming as naming
|
15
24
|
from ninetoothed.cudaifier import Cudaifier
|
@@ -19,19 +28,47 @@ from ninetoothed.tensor import Tensor
|
|
19
28
|
from ninetoothed.torchifier import Torchifier
|
20
29
|
|
21
30
|
CACHE_DIR = pathlib.Path.home() / ".ninetoothed"
|
31
|
+
CACHE_DIR.mkdir(exist_ok=True)
|
22
32
|
|
23
33
|
|
24
34
|
class CodeGenerator(ast.NodeTransformer):
|
25
35
|
def __init__(self):
    """Initialize the generator and load the per-block element limits.

    The upper limit is read from ``code_generator_cache.json`` inside
    ``CACHE_DIR``. When that file is missing, cannot be parsed, or does
    not hold an integer under ``"log2_max_num_elements"``, the limit is
    measured with ``_determine_log2_max_num_elements_per_block`` and
    the file is rewritten.
    """

    super().__init__()

    cache_file = CACHE_DIR / "code_generator_cache.json"

    log2_min_num_elements = 4

    # Try the read directly instead of checking `exists()` first, so a
    # file that is truncated, corrupt, or shaped unexpectedly is
    # treated like a missing one rather than aborting construction.
    try:
        with open(cache_file) as f:
            log2_max_num_elements = json.load(f)["log2_max_num_elements"]
    except (FileNotFoundError, ValueError, KeyError, TypeError):
        log2_max_num_elements = None

    if not isinstance(log2_max_num_elements, int):
        log2_max_num_elements = _determine_log2_max_num_elements_per_block(
            log2_min_num_elements
        )

        cache = {"log2_max_num_elements": log2_max_num_elements}

        with open(cache_file, "w") as f:
            json.dump(cache, f, indent=4)
            f.write("\n")

    self._min_num_elements = 2**log2_min_num_elements

    self._max_num_elements = 2**log2_max_num_elements
|
61
|
+
|
62
|
+
def __call__(
|
63
|
+
self,
|
64
|
+
func,
|
65
|
+
caller,
|
66
|
+
kernel_name,
|
67
|
+
num_warps,
|
68
|
+
num_stages,
|
69
|
+
max_num_configs,
|
70
|
+
prettify,
|
71
|
+
):
|
35
72
|
def _get_tree(func):
|
36
73
|
module = ast.parse(inspect.getsource(inspect.getmodule(func)))
|
37
74
|
|
@@ -63,6 +100,12 @@ class CodeGenerator(ast.NodeTransformer):
|
|
63
100
|
|
64
101
|
self._caller = caller
|
65
102
|
|
103
|
+
self._num_wraps = num_warps
|
104
|
+
|
105
|
+
self._num_stages = num_stages
|
106
|
+
|
107
|
+
self._max_num_configs = max_num_configs
|
108
|
+
|
66
109
|
self._context = inspect.get_annotations(func)
|
67
110
|
|
68
111
|
self._args = list(self._context.values())
|
@@ -94,9 +137,7 @@ class CodeGenerator(ast.NodeTransformer):
|
|
94
137
|
)
|
95
138
|
|
96
139
|
digest = hashlib.sha256(source.encode("utf-8")).hexdigest()
|
97
|
-
|
98
|
-
cache_dir.mkdir(exist_ok=True)
|
99
|
-
cache_file = cache_dir / f"{digest}.py"
|
140
|
+
cache_file = CACHE_DIR / f"{digest}.py"
|
100
141
|
|
101
142
|
if not cache_file.exists():
|
102
143
|
with open(cache_file, "w", encoding="utf-8") as f:
|
@@ -111,11 +152,15 @@ class CodeGenerator(ast.NodeTransformer):
|
|
111
152
|
def visit_Module(self, node):
    """Visit the module and append the generated trailing statements.

    After the children are visited, an assignment binding
    ``self._func_name_with_auto_tuning`` to the auto-tuning expression
    applied to the kernel function is appended, but only when
    ``self._autotune`` is not ``None``. ``self._launch`` is always
    appended last.

    :param node: The module node being transformed.
    :return: The same module node.
    """

    self.generic_visit(node)

    if self._autotune is not None:
        autotuned_call = f"{Symbol(self._autotune)}({self._func_def.name})"
        assignment = ast.parse(
            f"{self._func_name_with_auto_tuning} = {autotuned_call}"
        )

        node.body.append(assignment)

    node.body.append(self._launch)

    return node
|
@@ -137,8 +182,13 @@ class CodeGenerator(ast.NodeTransformer):
|
|
137
182
|
def visit_arguments(self, node):
|
138
183
|
self.generic_visit(node)
|
139
184
|
|
140
|
-
|
141
|
-
|
185
|
+
symbols = {
|
186
|
+
name.node.id: name
|
187
|
+
for arg in self._args
|
188
|
+
for name in arg.names()
|
189
|
+
if name != "ninetoothed"
|
190
|
+
}
|
191
|
+
names = symbols.keys()
|
142
192
|
meta_names = {name for name in names if naming.is_meta(name)}
|
143
193
|
non_meta_names = {name for name in names if name not in meta_names}
|
144
194
|
non_meta_names |= {
|
@@ -147,6 +197,8 @@ class CodeGenerator(ast.NodeTransformer):
|
|
147
197
|
if naming.is_constexpr(name)
|
148
198
|
}
|
149
199
|
|
200
|
+
self._symbols = symbols
|
201
|
+
|
150
202
|
non_meta_names = sorted(non_meta_names)
|
151
203
|
meta_names = sorted(meta_names)
|
152
204
|
|
@@ -161,6 +213,12 @@ class CodeGenerator(ast.NodeTransformer):
|
|
161
213
|
]
|
162
214
|
|
163
215
|
self._autotune = self._generate_autotune(non_meta_names, meta_names)
|
216
|
+
|
217
|
+
if self._autotune is not None:
|
218
|
+
self._func_name = self._func_name_with_auto_tuning
|
219
|
+
else:
|
220
|
+
self._func_name = self._func_def.name
|
221
|
+
|
164
222
|
self._func_def.decorator_list = [Symbol("triton.jit").node]
|
165
223
|
|
166
224
|
self._launch = self._generate_launch(non_meta_names, meta_names)
|
@@ -244,12 +302,69 @@ class CodeGenerator(ast.NodeTransformer):
|
|
244
302
|
return isinstance(node, ast.Name) and node.id in self._context
|
245
303
|
|
246
304
|
def _generate_autotune(self, params, meta):
|
247
|
-
|
248
|
-
|
249
|
-
|
305
|
+
inequalities = True
|
306
|
+
|
307
|
+
for arg in self._args:
|
308
|
+
if arg.ndim == 0:
|
309
|
+
continue
|
310
|
+
|
311
|
+
num_elements = sympy.simplify(str(math.prod(arg.innermost().shape)))
|
312
|
+
|
313
|
+
inequalities &= num_elements <= self._max_num_elements
|
314
|
+
inequalities &= num_elements >= self._min_num_elements
|
315
|
+
|
316
|
+
values_of_meta_params = []
|
317
|
+
|
318
|
+
for param in meta:
|
319
|
+
symbol = self._symbols[param]
|
320
|
+
|
321
|
+
values = range(symbol.lower_bound, symbol.upper_bound + 1)
|
322
|
+
|
323
|
+
if symbol.power_of_two:
|
324
|
+
values = tuple(value for value in values if value & (value - 1) == 0)
|
325
|
+
else:
|
326
|
+
values = tuple(values)
|
327
|
+
|
328
|
+
values_of_meta_params.append(values)
|
329
|
+
|
330
|
+
max_values_of_non_meta_params = {}
|
331
|
+
|
332
|
+
for free_symbol in inequalities.free_symbols:
|
333
|
+
symbol_str = str(free_symbol)
|
334
|
+
|
335
|
+
if symbol_str in meta:
|
336
|
+
continue
|
337
|
+
|
338
|
+
symbol = self._symbols[symbol_str]
|
339
|
+
|
340
|
+
max_values_of_non_meta_params[symbol_str] = symbol.upper_bound
|
341
|
+
|
342
|
+
block_size_configs = []
|
343
|
+
|
344
|
+
for values in itertools.product(*values_of_meta_params):
|
345
|
+
config = {param: value for param, value in zip(meta, values)}
|
346
|
+
|
347
|
+
if sympy.logic.simplify_logic(
|
348
|
+
inequalities.subs(config | max_values_of_non_meta_params)
|
349
|
+
):
|
350
|
+
block_size_configs.append(config)
|
351
|
+
|
352
|
+
if isinstance(self._num_wraps, collections.abc.Iterable):
|
353
|
+
num_warps_configs = self._num_wraps
|
354
|
+
else:
|
355
|
+
num_warps_configs = (self._num_wraps,)
|
356
|
+
|
357
|
+
if isinstance(self._num_stages, collections.abc.Iterable):
|
358
|
+
num_stages_configs = self._num_stages
|
359
|
+
else:
|
360
|
+
num_stages_configs = (self._num_stages,)
|
250
361
|
|
251
|
-
|
252
|
-
|
362
|
+
compiler_configs = tuple(
|
363
|
+
{"num_warps": num_warps, "num_stages": num_stages}
|
364
|
+
for num_warps, num_stages in itertools.product(
|
365
|
+
num_warps_configs, num_stages_configs
|
366
|
+
)
|
367
|
+
)
|
253
368
|
|
254
369
|
configs = [
|
255
370
|
ast.Call(
|
@@ -260,19 +375,38 @@ class CodeGenerator(ast.NodeTransformer):
|
|
260
375
|
),
|
261
376
|
args=[
|
262
377
|
ast.Dict(
|
263
|
-
keys=[
|
264
|
-
|
378
|
+
keys=[
|
379
|
+
ast.Constant(value=param)
|
380
|
+
for param in block_size_config.keys()
|
381
|
+
],
|
382
|
+
values=[
|
383
|
+
ast.Constant(value=value)
|
384
|
+
for value in block_size_config.values()
|
385
|
+
],
|
265
386
|
)
|
266
387
|
],
|
267
388
|
keywords=[
|
268
|
-
ast.keyword(
|
269
|
-
|
389
|
+
ast.keyword(
|
390
|
+
arg="num_warps",
|
391
|
+
value=ast.Constant(value=compiler_config["num_warps"]),
|
392
|
+
),
|
393
|
+
ast.keyword(
|
394
|
+
arg="num_stages",
|
395
|
+
value=ast.Constant(value=compiler_config["num_stages"]),
|
396
|
+
),
|
270
397
|
],
|
271
398
|
)
|
272
|
-
for
|
273
|
-
|
399
|
+
for block_size_config, compiler_config in itertools.product(
|
400
|
+
block_size_configs, compiler_configs
|
401
|
+
)
|
274
402
|
]
|
275
403
|
|
404
|
+
if self._max_num_configs is not None and len(configs) > self._max_num_configs:
|
405
|
+
configs = random.sample(configs, k=self._max_num_configs)
|
406
|
+
|
407
|
+
if not configs:
|
408
|
+
return None
|
409
|
+
|
276
410
|
return ast.Call(
|
277
411
|
func=ast.Attribute(
|
278
412
|
value=ast.Name(id="ninetoothed", ctx=ast.Load()),
|
@@ -358,9 +492,7 @@ class CodeGenerator(ast.NodeTransformer):
|
|
358
492
|
ast.Expr(
|
359
493
|
ast.Call(
|
360
494
|
func=ast.Subscript(
|
361
|
-
value=ast.Name(
|
362
|
-
id=self._func_name_with_auto_tuning, ctx=ast.Load()
|
363
|
-
),
|
495
|
+
value=ast.Name(id=self._func_name, ctx=ast.Load()),
|
364
496
|
slice=self._generate_grid(),
|
365
497
|
ctx=ast.Load(),
|
366
498
|
),
|
@@ -994,3 +1126,82 @@ class _FunctionDefFinder(ast.NodeVisitor):
|
|
994
1126
|
self.result = node
|
995
1127
|
|
996
1128
|
self.generic_visit(node)
|
1129
|
+
|
1130
|
+
|
1131
|
+
def _determine_log2_max_num_elements_per_block(
    min_exponent, max_exponent=30, num_iterations=3
):
    """Find the largest power-of-two block size that is still quick to profile.

    Block sizes ``2**n`` are tried for ``n`` from ``min_exponent`` up to
    ``max_exponent``. Each one is timed ``num_iterations`` times with
    ``_profile_pseudo_add_kernel``, and the search stops at the first
    ``n`` whose average time reaches one second.

    :param min_exponent: The first exponent to try.
    :param max_exponent: The last exponent to try.
    :param num_iterations: The number of timings averaged per exponent.
    :return: ``n - 1`` for the first exponent ``n`` whose average time
        is at least one second, or ``max_exponent`` when no exponent in
        the range is that slow.
    """

    # The result of this first run is discarded; presumably it absorbs
    # one-time start-up cost before the timed runs -- TODO confirm.
    _profile_pseudo_add_kernel(1)

    for n in range(min_exponent, max_exponent + 1):
        elapsed_time = 0

        for _ in range(num_iterations):
            elapsed_time += _profile_pseudo_add_kernel(2**n)

        average_elapsed_time = elapsed_time / num_iterations

        # NOTE: When this triggers on the very first exponent, the
        # result is `min_exponent - 1`, i.e. below the requested range.
        if average_elapsed_time >= 1:
            return n - 1

    # No exponent in the range reached the threshold. Return the upper
    # end explicitly instead of falling through with `None`, which
    # callers cannot use as an exponent.
    return max_exponent
|
1146
|
+
|
1147
|
+
|
1148
|
+
def _profile_pseudo_add_kernel(block_size):
    """Time one run of the pseudo add kernel with Triton's cache set aside.

    Triton's default cache directory is moved into a temporary
    directory for the duration of the run and moved back afterwards,
    even when the run raises. Anything written to the cache directory
    during the run is deleted before the original is restored.
    Presumably this makes every run start without cached artifacts --
    TODO confirm.

    :param block_size: The block size passed to ``_run_pseudo_add_kernel``.
    :return: The elapsed time of the run in seconds.
    """

    cache_dir = triton.runtime.cache.default_cache_dir()
    os.makedirs(cache_dir, exist_ok=True)

    with tempfile.TemporaryDirectory() as backup_dir:
        # The temporary directory is new and the name is a fresh UUID,
        # so nothing can already exist at this path.
        backup_path = os.path.join(backup_dir, str(uuid.uuid4()))

        shutil.move(cache_dir, backup_path)

        try:
            # Use a monotonic clock so that system clock adjustments
            # cannot distort the measured duration.
            start_time = time.perf_counter()

            _run_pseudo_add_kernel(block_size)

            elapsed_time = time.perf_counter() - start_time
        finally:
            if os.path.exists(cache_dir):
                shutil.rmtree(cache_dir)

            shutil.move(backup_path, cache_dir)

    return elapsed_time
|
1175
|
+
|
1176
|
+
|
1177
|
+
def _run_pseudo_add_kernel(block_size):
    """Launch a masked element-wise add kernel over zero elements.

    The three operands are ``Tensor`` objects of shape ``(0,)`` whose
    ``data_ptr`` is replaced by a function returning ``0``. The kernel
    is launched on a grid of ``(1,)``.

    :param block_size: The value passed as the kernel's ``BLOCK_SIZE``.
    """

    @triton.jit
    def kernel(a_ptr, b_ptr, c_ptr, num_elements, BLOCK_SIZE: tl.constexpr):
        offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        in_bounds = offsets < num_elements

        lhs = tl.load(a_ptr + offsets, mask=in_bounds)
        rhs = tl.load(b_ptr + offsets, mask=in_bounds)

        tl.store(c_ptr + offsets, lhs + rhs, mask=in_bounds)

    def data_ptr():
        return 0

    num_elements = 0

    operands = []

    for _ in range(3):
        operand = Tensor(shape=(num_elements,), dtype=tl.float32)
        # Every operand reports the same placeholder address.
        operand.data_ptr = data_ptr
        operands.append(operand)

    kernel[(1,)](*operands, num_elements, block_size)
|
ninetoothed/jit.py
CHANGED
@@ -4,12 +4,25 @@ import sys
|
|
4
4
|
from ninetoothed.generation import CodeGenerator
|
5
5
|
|
6
6
|
|
7
|
-
def jit(
|
7
|
+
def jit(
|
8
|
+
func=None,
|
9
|
+
*,
|
10
|
+
caller="torch",
|
11
|
+
kernel_name=None,
|
12
|
+
num_warps=None,
|
13
|
+
num_stages=None,
|
14
|
+
max_num_configs=None,
|
15
|
+
_prettify=False,
|
16
|
+
):
|
8
17
|
"""A decorator for generating compute kernels.
|
9
18
|
|
10
19
|
:param func: The function to be compiled.
|
11
20
|
:param caller: Who will call the compute kernel.
|
12
21
|
:param kernel_name: The name for the generated kernel.
|
22
|
+
:param num_warps: The number of warps to use.
|
23
|
+
:param num_stages: The number of pipeline stages.
|
24
|
+
:param max_num_configs: The maximum number of auto-tuning
|
25
|
+
configurations to use.
|
13
26
|
:param _prettify: Whether to prettify the generated code.
|
14
27
|
:return: A handle to the compute kernel.
|
15
28
|
|
@@ -20,7 +33,15 @@ def jit(func=None, *, caller="torch", kernel_name=None, _prettify=False):
|
|
20
33
|
"""
|
21
34
|
|
22
35
|
def wrapper(func):
|
23
|
-
return JIT(
|
36
|
+
return JIT(
|
37
|
+
func,
|
38
|
+
caller=caller,
|
39
|
+
kernel_name=kernel_name,
|
40
|
+
num_warps=num_warps,
|
41
|
+
num_stages=num_stages,
|
42
|
+
max_num_configs=max_num_configs,
|
43
|
+
_prettify=_prettify,
|
44
|
+
)()
|
24
45
|
|
25
46
|
if func is None:
|
26
47
|
return wrapper
|
@@ -29,7 +50,16 @@ def jit(func=None, *, caller="torch", kernel_name=None, _prettify=False):
|
|
29
50
|
|
30
51
|
|
31
52
|
class JIT:
|
32
|
-
def __init__(
|
53
|
+
def __init__(
|
54
|
+
self,
|
55
|
+
func,
|
56
|
+
caller,
|
57
|
+
kernel_name,
|
58
|
+
num_warps,
|
59
|
+
num_stages,
|
60
|
+
max_num_configs,
|
61
|
+
_prettify=False,
|
62
|
+
):
|
33
63
|
self.func = func
|
34
64
|
|
35
65
|
self._caller = caller
|
@@ -39,12 +69,24 @@ class JIT:
|
|
39
69
|
else:
|
40
70
|
self._kernel_name = func.__name__
|
41
71
|
|
72
|
+
self._num_warps = num_warps
|
73
|
+
|
74
|
+
self._num_stages = num_stages
|
75
|
+
|
76
|
+
self._max_num_configs = max_num_configs
|
77
|
+
|
42
78
|
self._prettify = _prettify
|
43
79
|
|
44
80
|
def __call__(self):
|
45
81
|
code_generator = CodeGenerator()
|
46
82
|
source_file = code_generator(
|
47
|
-
self.func,
|
83
|
+
self.func,
|
84
|
+
self._caller,
|
85
|
+
self._kernel_name,
|
86
|
+
self._num_warps,
|
87
|
+
self._num_stages,
|
88
|
+
self._max_num_configs,
|
89
|
+
self._prettify,
|
48
90
|
)
|
49
91
|
module = type(self)._import_from_path(source_file, source_file)
|
50
92
|
module_vars = vars(module)
|
ninetoothed/make.py
CHANGED
@@ -2,6 +2,7 @@ import inspect
|
|
2
2
|
|
3
3
|
from ninetoothed.aot import aot
|
4
4
|
from ninetoothed.jit import jit
|
5
|
+
from ninetoothed.utils import calculate_default_configs
|
5
6
|
|
6
7
|
|
7
8
|
def make(
|
@@ -11,8 +12,9 @@ def make(
|
|
11
12
|
caller="torch",
|
12
13
|
kernel_name=None,
|
13
14
|
output_dir=None,
|
14
|
-
num_warps=
|
15
|
-
num_stages=
|
15
|
+
num_warps=None,
|
16
|
+
num_stages=None,
|
17
|
+
max_num_configs=None,
|
16
18
|
):
|
17
19
|
"""Integrate the arrangement and the application of the tensors.
|
18
20
|
|
@@ -24,16 +26,33 @@ def make(
|
|
24
26
|
:param output_dir: The directory to store the generated files.
|
25
27
|
:param num_warps: The number of warps to use.
|
26
28
|
:param num_stages: The number of pipeline stages.
|
29
|
+
:param max_num_configs: The maximum number of auto-tuning
|
30
|
+
configurations to use.
|
27
31
|
:return: A handle to the compute kernel.
|
28
32
|
"""
|
29
33
|
|
34
|
+
default_num_warps, default_num_stages = calculate_default_configs()
|
35
|
+
|
36
|
+
if num_warps is None:
|
37
|
+
num_warps = default_num_warps
|
38
|
+
|
39
|
+
if num_stages is None:
|
40
|
+
num_stages = default_num_stages
|
41
|
+
|
30
42
|
params = inspect.signature(application).parameters
|
31
43
|
types = arrangement(*tensors)
|
32
44
|
annotations = {param: type for param, type in zip(params, types)}
|
33
45
|
application.__annotations__ = annotations
|
34
46
|
|
35
47
|
if caller == "torch":
|
36
|
-
return jit(
|
48
|
+
return jit(
|
49
|
+
application,
|
50
|
+
caller=caller,
|
51
|
+
kernel_name=kernel_name,
|
52
|
+
num_warps=num_warps,
|
53
|
+
num_stages=num_stages,
|
54
|
+
max_num_configs=max_num_configs,
|
55
|
+
)
|
37
56
|
|
38
57
|
return aot(
|
39
58
|
application,
|
ninetoothed/symbol.py
CHANGED
@@ -12,9 +12,20 @@ class Symbol:
|
|
12
12
|
:param expr: The expression used to construct the symbol.
|
13
13
|
:param constexpr: Whether the symbol is a constexpr.
|
14
14
|
:param meta: Whether the symbol is a meta.
|
15
|
+
:param lower_bound: The minimum value for the symbol's range.
|
16
|
+
:param upper_bound: The maximum value for the symbol's range.
|
17
|
+
:param power_of_two: Whether the value should be a power of two.
|
15
18
|
"""
|
16
19
|
|
17
|
-
def __init__(
|
20
|
+
def __init__(
|
21
|
+
self,
|
22
|
+
expr,
|
23
|
+
constexpr=None,
|
24
|
+
meta=None,
|
25
|
+
lower_bound=None,
|
26
|
+
upper_bound=None,
|
27
|
+
power_of_two=None,
|
28
|
+
):
|
18
29
|
if isinstance(expr, type(self)):
|
19
30
|
self._node = expr._node
|
20
31
|
return
|
@@ -43,6 +54,40 @@ class Symbol:
|
|
43
54
|
if constexpr:
|
44
55
|
self._node.id = naming.make_constexpr(self._node.id)
|
45
56
|
|
57
|
+
self._node.symbol = self
|
58
|
+
|
59
|
+
DEFAULT_LOWER_BOUND_FOR_META_SYMBOLS = 2**5
|
60
|
+
DEFAULT_UPPER_BOUND_FOR_META_SYMBOLS = 2**10
|
61
|
+
DEFAULT_POWER_OF_TWO_FOR_META_SYMBOLS = True
|
62
|
+
|
63
|
+
DEFAULT_LOWER_BOUND_FOR_NON_META_CONSTEXPR_SYMBOLS = 1
|
64
|
+
DEFAULT_UPPER_BOUND_FOR_NON_META_CONSTEXPR_SYMBOLS = 2**20
|
65
|
+
DEFAULT_POWER_OF_TWO_FOR_NON_META_CONSTEXPR_SYMBOLS = False
|
66
|
+
|
67
|
+
if lower_bound is not None:
|
68
|
+
self.lower_bound = lower_bound
|
69
|
+
else:
|
70
|
+
if meta:
|
71
|
+
self.lower_bound = DEFAULT_LOWER_BOUND_FOR_META_SYMBOLS
|
72
|
+
elif constexpr:
|
73
|
+
self.lower_bound = DEFAULT_LOWER_BOUND_FOR_NON_META_CONSTEXPR_SYMBOLS
|
74
|
+
|
75
|
+
if upper_bound is not None:
|
76
|
+
self.upper_bound = upper_bound
|
77
|
+
else:
|
78
|
+
if meta:
|
79
|
+
self.upper_bound = DEFAULT_UPPER_BOUND_FOR_META_SYMBOLS
|
80
|
+
elif constexpr:
|
81
|
+
self.upper_bound = DEFAULT_UPPER_BOUND_FOR_NON_META_CONSTEXPR_SYMBOLS
|
82
|
+
|
83
|
+
if power_of_two is not None:
|
84
|
+
self.power_of_two = power_of_two
|
85
|
+
else:
|
86
|
+
if meta:
|
87
|
+
self.power_of_two = DEFAULT_POWER_OF_TWO_FOR_META_SYMBOLS
|
88
|
+
elif constexpr:
|
89
|
+
self.power_of_two = DEFAULT_POWER_OF_TWO_FOR_NON_META_CONSTEXPR_SYMBOLS
|
90
|
+
|
46
91
|
def __eq__(self, other):
|
47
92
|
if isinstance(self._node, ast.Constant):
|
48
93
|
if isinstance(other, Symbol) and isinstance(other._node, ast.Constant):
|
@@ -155,7 +200,7 @@ class Symbol:
|
|
155
200
|
def visit_Name(self, node):
|
156
201
|
self.generic_visit(node)
|
157
202
|
|
158
|
-
self.names.add(node.
|
203
|
+
self.names.add(node.symbol)
|
159
204
|
|
160
205
|
name_collector = NameCollector()
|
161
206
|
|
@@ -179,6 +224,24 @@ class Symbol:
|
|
179
224
|
return isinstance(object, Symbol) and isinstance(object.node, ast.Name)
|
180
225
|
|
181
226
|
|
227
|
+
def block_size(lower_bound=None, upper_bound=None):
    """Build a new block size symbol to be used as a meta-parameter.

    Each call names its symbol after the current value of a counter
    stored on the function itself, then advances that counter.

    :param lower_bound: The lower bound for the block size's range.
    :param upper_bound: The upper bound for the block size's range.
    :return: A block size symbol that serves as a meta-parameter.
    """

    index = block_size._num_block_sizes
    symbol_name = naming.auto_generate(f"BLOCK_SIZE_{index}")

    # Advance the counter only after the name has been generated.
    block_size._num_block_sizes = index + 1

    bounds = {"lower_bound": lower_bound, "upper_bound": upper_bound}

    return Symbol(symbol_name, meta=True, **bounds)


# Number of block size symbols created so far through `block_size`.
block_size._num_block_sizes = 0
|
243
|
+
|
244
|
+
|
182
245
|
class _FindAndReplacer(ast.NodeTransformer):
|
183
246
|
def __init__(self, targets, replacement):
|
184
247
|
self._targets_unparsed = tuple(
|
ninetoothed/tensor.py
CHANGED
@@ -14,7 +14,7 @@ class Tensor:
|
|
14
14
|
:param dtype: The element type of the tensor.
|
15
15
|
:param strides: The strides of the tensor.
|
16
16
|
:param other: The values for out-of-bounds positions.
|
17
|
-
:param
|
17
|
+
:param shape_options: The options for configuring shape symbols.
|
18
18
|
:param name: The name of the tensor.
|
19
19
|
:param source: For internal use only.
|
20
20
|
:param source_dims: For internal use only.
|
@@ -31,7 +31,7 @@ class Tensor:
|
|
31
31
|
dtype=None,
|
32
32
|
strides=None,
|
33
33
|
other=None,
|
34
|
-
|
34
|
+
shape_options=None,
|
35
35
|
name=None,
|
36
36
|
source=None,
|
37
37
|
source_dims=None,
|
@@ -48,9 +48,20 @@ class Tensor:
|
|
48
48
|
self.name = naming.auto_generate(f"tensor_{type(self).num_instances}")
|
49
49
|
|
50
50
|
if ndim is not None:
|
51
|
+
if shape_options is None:
|
52
|
+
shape_options = tuple({} for _ in range(ndim))
|
53
|
+
|
54
|
+
if isinstance(shape_options, dict):
|
55
|
+
shape_options = tuple(shape_options for _ in range(ndim))
|
56
|
+
|
57
|
+
shape_options = tuple(
|
58
|
+
size_options if size_options is not None else {}
|
59
|
+
for size_options in shape_options
|
60
|
+
)
|
61
|
+
|
51
62
|
self.shape = (
|
52
|
-
Symbol(self.size_string(i),
|
53
|
-
for i in range(ndim)
|
63
|
+
Symbol(self.size_string(i), **size_options)
|
64
|
+
for i, size_options in zip(range(ndim), shape_options)
|
54
65
|
)
|
55
66
|
self.strides = (Symbol(self.stride_string(i)) for i in range(ndim))
|
56
67
|
else:
|
@@ -364,10 +375,10 @@ class Tensor:
|
|
364
375
|
|
365
376
|
def names(self):
|
366
377
|
if self.ndim == 0:
|
367
|
-
return {self.source.name}
|
378
|
+
return {Symbol(self.source.name)}
|
368
379
|
|
369
380
|
return (
|
370
|
-
{self.source.pointer_string()}
|
381
|
+
{Symbol(self.source.pointer_string())}
|
371
382
|
| {
|
372
383
|
name
|
373
384
|
for value in itertools.chain(self.shape, self.strides)
|
ninetoothed/utils.py
ADDED
@@ -0,0 +1,12 @@
|
|
1
|
+
import triton
|
2
|
+
|
3
|
+
|
4
|
+
def calculate_default_configs():
    """Return the default ``(num_warps, num_stages)`` pair.

    ``num_warps`` is fixed at 8. ``num_stages`` is the current device's
    ``max_shared_mem`` property floor-divided by ``2**15``.

    :return: A tuple of ``num_warps`` and ``num_stages``.
    """

    active_driver = triton.runtime.driver.active
    current_device = active_driver.get_current_device()
    device_properties = active_driver.utils.get_device_properties(current_device)

    return 8, device_properties["max_shared_mem"] // 2**15
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ninetoothed
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.16.0
|
4
4
|
Summary: A domain-specific language based on Triton but providing higher-level abstraction.
|
5
5
|
Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
|
6
6
|
Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
|
@@ -10,6 +10,7 @@ Classifier: License :: OSI Approved :: Apache Software License
|
|
10
10
|
Classifier: Operating System :: OS Independent
|
11
11
|
Classifier: Programming Language :: Python :: 3
|
12
12
|
Requires-Python: >=3.10
|
13
|
+
Requires-Dist: sympy>=1.13.0
|
13
14
|
Requires-Dist: triton>=3.0.0
|
14
15
|
Provides-Extra: all
|
15
16
|
Requires-Dist: matplotlib>=3.9.0; extra == 'all'
|
@@ -0,0 +1,18 @@
|
|
1
|
+
ninetoothed/__init__.py,sha256=F2bxRNhzcGdtADA8RehTuf-QK0xnxno8kxvr6H2L5Tg,552
|
2
|
+
ninetoothed/aot.py,sha256=8ZCLtnsign14YvY7SXX5ASidhuUAhPwppTXUJNkQup4,6243
|
3
|
+
ninetoothed/cudaifier.py,sha256=5ylMr1q0B9NwbeXkpCu3o2nMGpDfh65nAQ0Az_qMQuI,877
|
4
|
+
ninetoothed/dtype.py,sha256=-0iBleay5gYA4wtT3l17QjCesr7g26M6CSfhNJdI3k4,165
|
5
|
+
ninetoothed/generation.py,sha256=VIqSyZT4yHxY_a2QPmWW6jjALv3e1mohDqdRQBRYsAo,36462
|
6
|
+
ninetoothed/jit.py,sha256=CpeSkO_zUe9DwtTJ2K2H7Bwpx-FvIHfrgzOcEosfpek,2946
|
7
|
+
ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
|
8
|
+
ninetoothed/make.py,sha256=fQKuRJL7HC2iGTAN323mlIWXz9Z3jotIoN68ur29Qlw,1834
|
9
|
+
ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
|
10
|
+
ninetoothed/symbol.py,sha256=lJo3NL2-T7tKbKjb6MCRLMemN94mqS3bIiG943P0Mbo,7454
|
11
|
+
ninetoothed/tensor.py,sha256=gQEzHTcXqZVBFLc2YRfXTKxjxPWMxWN7fNl2BCfJwMs,14782
|
12
|
+
ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
|
13
|
+
ninetoothed/utils.py,sha256=mtRXABBVPnlgd2n1REh9oB3s_5bUsKhd3iwu3oJ5DSQ,338
|
14
|
+
ninetoothed/visualization.py,sha256=zlMH-0WplaboePGzcbpcj4UovpX0k2r4SysSPsNS4r4,3674
|
15
|
+
ninetoothed-0.16.0.dist-info/METADATA,sha256=nkq3iImebtmcEs-bZq2zfF2_QxrZD9IWky1S86OnUMA,7340
|
16
|
+
ninetoothed-0.16.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
17
|
+
ninetoothed-0.16.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
18
|
+
ninetoothed-0.16.0.dist-info/RECORD,,
|
@@ -1,17 +0,0 @@
|
|
1
|
-
ninetoothed/__init__.py,sha256=zGaZiUzwJZ2jfwLxp7lT8ll_V5ngP5QYrfVbapftbCY,522
|
2
|
-
ninetoothed/aot.py,sha256=5P9s-KAA7xNNdK8_fbCZEIteQlbaB_1wOl8_rEBQg9U,6128
|
3
|
-
ninetoothed/cudaifier.py,sha256=5ylMr1q0B9NwbeXkpCu3o2nMGpDfh65nAQ0Az_qMQuI,877
|
4
|
-
ninetoothed/dtype.py,sha256=-0iBleay5gYA4wtT3l17QjCesr7g26M6CSfhNJdI3k4,165
|
5
|
-
ninetoothed/generation.py,sha256=Gmeh9OPmWZmF9CUY-UIIBPi-SOjFCxZjvXNwqX3uD84,30963
|
6
|
-
ninetoothed/jit.py,sha256=0MFbFIODtw-bxuOC7WByxiVtQMeyvZkoDxvfAZ9rIFQ,2120
|
7
|
-
ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
|
8
|
-
ninetoothed/make.py,sha256=wRr3JwGt5E2OCquq_nzBZljdW-AJPOqH49cM08gwl4A,1287
|
9
|
-
ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
|
10
|
-
ninetoothed/symbol.py,sha256=UpGmx_jvaDtowADnp1DwYC3fvBXSiaMiYpU-ewkVo50,5261
|
11
|
-
ninetoothed/tensor.py,sha256=ByTnoeqxD9lXprvy1DDp5L-zU2up52-jop9AAUrSTYk,14347
|
12
|
-
ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
|
13
|
-
ninetoothed/visualization.py,sha256=zlMH-0WplaboePGzcbpcj4UovpX0k2r4SysSPsNS4r4,3674
|
14
|
-
ninetoothed-0.15.1.dist-info/METADATA,sha256=6RA1-6fYFfSTJnnwsRoVb4yIRrn4kfLhN47GNvmGji0,7311
|
15
|
-
ninetoothed-0.15.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
16
|
-
ninetoothed-0.15.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
17
|
-
ninetoothed-0.15.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|