ninetoothed 0.15.0-py3-none-any.whl → 0.16.0-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
ninetoothed/__init__.py CHANGED
@@ -13,12 +13,13 @@ from ninetoothed.dtype import (
 )
 from ninetoothed.jit import jit
 from ninetoothed.make import make
-from ninetoothed.symbol import Symbol
+from ninetoothed.symbol import Symbol, block_size
 from ninetoothed.tensor import Tensor
 
 __all__ = [
     "Symbol",
     "Tensor",
+    "block_size",
     "float16",
     "float32",
     "float64",
ninetoothed/aot.py CHANGED
@@ -4,7 +4,7 @@ import subprocess
 import tempfile
 import uuid
 
-from ninetoothed.dtype import int64, uint64
+from ninetoothed.dtype import int64
 from ninetoothed.generation import CACHE_DIR, CodeGenerator
 from ninetoothed.tensor import Tensor
 
@@ -36,7 +36,13 @@ def _aot(func, caller, kernel_name, num_warps, num_stages):
 
     code_generator = CodeGenerator()
     source_file = code_generator(
-        func, caller=caller, kernel_name=kernel_name, prettify=False
+        func,
+        caller=caller,
+        kernel_name=kernel_name,
+        num_warps=num_warps,
+        num_stages=num_stages,
+        max_num_configs=None,
+        prettify=False,
     )
 
     tensors = code_generator.tensors
@@ -55,7 +61,7 @@ def _aot(func, caller, kernel_name, num_warps, num_stages):
 
             param_types.append(f"*{dtype}")
         elif Tensor.size_pattern().fullmatch(param):
-            param_types.append(uint64)
+            param_types.append(int64)
         elif Tensor.stride_pattern().fullmatch(param):
            param_types.append(int64)
 
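
Besides threading num_warps and num_stages through to the code generator, this change makes AOT size parameters signed (int64, matching strides) instead of uint64. A rough standalone sketch of the parameter-type mapping; the regexes here are placeholders standing in for Tensor.size_pattern() and Tensor.stride_pattern(), whose exact patterns are not shown in this diff:

    import re

    # Hypothetical stand-ins for Tensor.size_pattern() / Tensor.stride_pattern().
    SIZE_PATTERN = re.compile(r".+_size_\d+")
    STRIDE_PATTERN = re.compile(r".+_stride_\d+")

    def param_type(param, dtype="fp32"):
        if SIZE_PATTERN.fullmatch(param):
            return "int64"  # was "uint64" in 0.15.0
        if STRIDE_PATTERN.fullmatch(param):
            return "int64"
        return f"*{dtype}"  # everything else is treated as a tensor pointer

    assert param_type("tensor_0_size_0") == "int64"
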
ninetoothed/generation.py CHANGED
@@ -5,11 +5,20 @@ import functools
 import hashlib
 import inspect
 import itertools
+import json
 import math
+import os
 import pathlib
+import random
+import shutil
 import subprocess
+import tempfile
+import time
+import uuid
 
+import sympy
 import triton
+import triton.language as tl
 
 import ninetoothed.naming as naming
 from ninetoothed.cudaifier import Cudaifier
@@ -19,19 +28,47 @@ from ninetoothed.tensor import Tensor
 from ninetoothed.torchifier import Torchifier
 
 CACHE_DIR = pathlib.Path.home() / ".ninetoothed"
+CACHE_DIR.mkdir(exist_ok=True)
 
 
 class CodeGenerator(ast.NodeTransformer):
     def __init__(self):
         super().__init__()
 
-        self._POWER_OF_TWOS = tuple(2**n for n in range(5, 11))
+        cache_file = CACHE_DIR / "code_generator_cache.json"
 
-        self._MIN_PRODUCT = 2**10
+        log2_min_num_elements = 4
 
-        self._MAX_PRODUCT = 2**20
+        if cache_file.exists():
+            with open(cache_file) as f:
+                cache = json.load(f)
 
-    def __call__(self, func, caller, kernel_name, prettify):
+            log2_max_num_elements = cache["log2_max_num_elements"]
+        else:
+            log2_max_num_elements = _determine_log2_max_num_elements_per_block(
+                log2_min_num_elements
+            )
+
+            cache = {"log2_max_num_elements": log2_max_num_elements}
+
+            with open(cache_file, "w") as f:
+                json.dump(cache, f, indent=4)
+                f.write("\n")
+
+        self._min_num_elements = 2**log2_min_num_elements
+
+        self._max_num_elements = 2**log2_max_num_elements
+
+    def __call__(
+        self,
+        func,
+        caller,
+        kernel_name,
+        num_warps,
+        num_stages,
+        max_num_configs,
+        prettify,
+    ):
         def _get_tree(func):
             module = ast.parse(inspect.getsource(inspect.getmodule(func)))
 
@@ -63,6 +100,12 @@ class CodeGenerator(ast.NodeTransformer):
 
         self._caller = caller
 
+        self._num_wraps = num_warps
+
+        self._num_stages = num_stages
+
+        self._max_num_configs = max_num_configs
+
         self._context = inspect.get_annotations(func)
 
         self._args = list(self._context.values())
@@ -82,6 +125,7 @@ class CodeGenerator(ast.NodeTransformer):
         dependencies = _find_dependencies(func)
         source = "\n\n".join((unparsed, dependencies)).strip()
         source = source.replace(func.__name__, kernel_name)
+        source += "\n"
 
         if prettify:
             for original, simplified in name_collector.simplified_names.items():
@@ -93,9 +137,7 @@ class CodeGenerator(ast.NodeTransformer):
             )
 
         digest = hashlib.sha256(source.encode("utf-8")).hexdigest()
-        cache_dir = CACHE_DIR
-        cache_dir.mkdir(exist_ok=True)
-        cache_file = cache_dir / f"{digest}.py"
+        cache_file = CACHE_DIR / f"{digest}.py"
 
        if not cache_file.exists():
            with open(cache_file, "w", encoding="utf-8") as f:
@@ -110,11 +152,15 @@ class CodeGenerator(ast.NodeTransformer):
     def visit_Module(self, node):
         self.generic_visit(node)
 
-        func_with_auto_tuning = f"{Symbol(self._autotune)}({self._func_def.name})"
+        if self._autotune is not None:
+            func_with_auto_tuning = f"{Symbol(self._autotune)}({self._func_def.name})"
+
+            node.body.append(
+                ast.parse(
+                    f"{self._func_name_with_auto_tuning} = {func_with_auto_tuning}"
+                )
+            )
 
-        node.body.append(
-            ast.parse(f"{self._func_name_with_auto_tuning} = {func_with_auto_tuning}")
-        )
         node.body.append(self._launch)
 
         return node
@@ -136,8 +182,13 @@ class CodeGenerator(ast.NodeTransformer):
     def visit_arguments(self, node):
         self.generic_visit(node)
 
-        names_of_args = [arg.names() - {"ninetoothed"} for arg in self._args]
-        names = functools.reduce(lambda x, y: x | y, names_of_args)
+        symbols = {
+            name.node.id: name
+            for arg in self._args
+            for name in arg.names()
+            if name != "ninetoothed"
+        }
+        names = symbols.keys()
         meta_names = {name for name in names if naming.is_meta(name)}
         non_meta_names = {name for name in names if name not in meta_names}
         non_meta_names |= {
@@ -146,6 +197,8 @@ class CodeGenerator(ast.NodeTransformer):
             if naming.is_constexpr(name)
         }
 
+        self._symbols = symbols
+
         non_meta_names = sorted(non_meta_names)
         meta_names = sorted(meta_names)
 
@@ -160,6 +213,12 @@ class CodeGenerator(ast.NodeTransformer):
        ]
 
         self._autotune = self._generate_autotune(non_meta_names, meta_names)
+
+        if self._autotune is not None:
+            self._func_name = self._func_name_with_auto_tuning
+        else:
+            self._func_name = self._func_def.name
+
         self._func_def.decorator_list = [Symbol("triton.jit").node]
 
         self._launch = self._generate_launch(non_meta_names, meta_names)
@@ -243,12 +302,69 @@ class CodeGenerator(ast.NodeTransformer):
         return isinstance(node, ast.Name) and node.id in self._context
 
     def _generate_autotune(self, params, meta):
-        device = triton.runtime.driver.active.get_current_device()
-        properties = triton.runtime.driver.active.utils.get_device_properties(device)
-        max_shared_mem = properties["max_shared_mem"]
+        inequalities = True
+
+        for arg in self._args:
+            if arg.ndim == 0:
+                continue
+
+            num_elements = sympy.simplify(str(math.prod(arg.innermost().shape)))
+
+            inequalities &= num_elements <= self._max_num_elements
+            inequalities &= num_elements >= self._min_num_elements
+
+        values_of_meta_params = []
+
+        for param in meta:
+            symbol = self._symbols[param]
+
+            values = range(symbol.lower_bound, symbol.upper_bound + 1)
+
+            if symbol.power_of_two:
+                values = tuple(value for value in values if value & (value - 1) == 0)
+            else:
+                values = tuple(values)
+
+            values_of_meta_params.append(values)
+
+        max_values_of_non_meta_params = {}
+
+        for free_symbol in inequalities.free_symbols:
+            symbol_str = str(free_symbol)
+
+            if symbol_str in meta:
+                continue
+
+            symbol = self._symbols[symbol_str]
+
+            max_values_of_non_meta_params[symbol_str] = symbol.upper_bound
+
+        block_size_configs = []
+
+        for values in itertools.product(*values_of_meta_params):
+            config = {param: value for param, value in zip(meta, values)}
+
+            if sympy.logic.simplify_logic(
+                inequalities.subs(config | max_values_of_non_meta_params)
+            ):
+                block_size_configs.append(config)
+
+        if isinstance(self._num_wraps, collections.abc.Iterable):
+            num_warps_configs = self._num_wraps
+        else:
+            num_warps_configs = (self._num_wraps,)
+
+        if isinstance(self._num_stages, collections.abc.Iterable):
+            num_stages_configs = self._num_stages
+        else:
+            num_stages_configs = (self._num_stages,)
 
-        num_warps = 8
-        num_stages = max_shared_mem // 2**15
+        compiler_configs = tuple(
+            {"num_warps": num_warps, "num_stages": num_stages}
+            for num_warps, num_stages in itertools.product(
+                num_warps_configs, num_stages_configs
+            )
        )
 
         configs = [
             ast.Call(
@@ -259,19 +375,38 @@ class CodeGenerator(ast.NodeTransformer):
                 ),
                 args=[
                     ast.Dict(
-                        keys=[ast.Constant(value=param) for param in meta],
-                        values=[ast.Constant(value=value) for value in values],
+                        keys=[
+                            ast.Constant(value=param)
+                            for param in block_size_config.keys()
+                        ],
+                        values=[
+                            ast.Constant(value=value)
+                            for value in block_size_config.values()
+                        ],
                     )
                 ],
                 keywords=[
-                    ast.keyword(arg="num_warps", value=ast.Constant(value=num_warps)),
-                    ast.keyword(arg="num_stages", value=ast.Constant(value=num_stages)),
+                    ast.keyword(
+                        arg="num_warps",
+                        value=ast.Constant(value=compiler_config["num_warps"]),
+                    ),
+                    ast.keyword(
+                        arg="num_stages",
+                        value=ast.Constant(value=compiler_config["num_stages"]),
+                    ),
                 ],
             )
-            for values in itertools.product(self._POWER_OF_TWOS, repeat=len(meta))
-            if self._MIN_PRODUCT <= math.prod(values) <= self._MAX_PRODUCT
+            for block_size_config, compiler_config in itertools.product(
+                block_size_configs, compiler_configs
+            )
         ]
 
+        if self._max_num_configs is not None and len(configs) > self._max_num_configs:
+            configs = random.sample(configs, k=self._max_num_configs)
+
+        if not configs:
+            return None
+
         return ast.Call(
             func=ast.Attribute(
                 value=ast.Name(id="ninetoothed", ctx=ast.Load()),
@@ -357,9 +492,7 @@ class CodeGenerator(ast.NodeTransformer):
                 ast.Expr(
                     ast.Call(
                         func=ast.Subscript(
-                            value=ast.Name(
-                                id=self._func_name_with_auto_tuning, ctx=ast.Load()
-                            ),
+                            value=ast.Name(id=self._func_name, ctx=ast.Load()),
                             slice=self._generate_grid(),
                             ctx=ast.Load(),
                         ),
@@ -993,3 +1126,82 @@ class _FunctionDefFinder(ast.NodeVisitor):
             self.result = node
 
         self.generic_visit(node)
+
+
+def _determine_log2_max_num_elements_per_block(
+    min_exponent, max_exponent=30, num_iterations=3
+):
+    _profile_pseudo_add_kernel(1)
+
+    for n in range(min_exponent, max_exponent + 1):
+        elapsed_time = 0
+
+        for _ in range(num_iterations):
+            elapsed_time += _profile_pseudo_add_kernel(2**n)
+
+        average_elapsed_time = elapsed_time / num_iterations
+
+        if average_elapsed_time >= 1:
+            return n - 1
+
+
+def _profile_pseudo_add_kernel(block_size):
+    cache_dir = triton.runtime.cache.default_cache_dir()
+    os.makedirs(cache_dir, exist_ok=True)
+
+    with tempfile.TemporaryDirectory() as backup_dir:
+        backup_path = os.path.join(backup_dir, str(uuid.uuid4()))
+
+        if os.path.exists(backup_path):
+            shutil.rmtree(backup_path)
+
+        shutil.move(cache_dir, backup_path)
+
+        try:
+            start_time = time.time()
+
+            _run_pseudo_add_kernel(block_size)
+
+            end_time = time.time()
+
+            elapsed_time = end_time - start_time
+        finally:
+            if os.path.exists(cache_dir):
+                shutil.rmtree(cache_dir)
+
+            shutil.move(backup_path, cache_dir)
+
+    return elapsed_time
+
+
+def _run_pseudo_add_kernel(block_size):
+    @triton.jit
+    def kernel(a_ptr, b_ptr, c_ptr, num_elements, BLOCK_SIZE: tl.constexpr):
+        pid = tl.program_id(0)
+
+        offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        mask = offs < num_elements
+
+        a = tl.load(a_ptr + offs, mask=mask)
+        b = tl.load(b_ptr + offs, mask=mask)
+
+        c = a + b
+
+        tl.store(c_ptr + offs, c, mask=mask)
+
+    num_elements = 0
+    shape = (num_elements,)
+    dtype = tl.float32
+
+    a = Tensor(shape=shape, dtype=dtype)
+    b = Tensor(shape=shape, dtype=dtype)
+    c = Tensor(shape=shape, dtype=dtype)
+
+    def data_ptr():
+        return 0
+
+    a.data_ptr = data_ptr
+    b.data_ptr = data_ptr
+    c.data_ptr = data_ptr
+
+    kernel[(1,)](a, b, c, num_elements, block_size)
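
This replaces the fixed 0.15.0 search grid (powers of two from 2**5 to 2**10, product bounded by 2**10..2**20) with a two-step scheme: first, a pseudo add kernel is timed against a cold Triton cache to find the largest per-block element count whose compile-plus-launch time stays under one second on average, and the resulting exponent is cached in ~/.ninetoothed/code_generator_cache.json; second, per-symbol candidate values are enumerated and kept only if the block element counts can satisfy those bounds, checked symbolically with SymPy. A self-contained sketch of that pruning step, with assumed bounds and two hypothetical meta-parameters (this is not the library's API):

    import itertools

    import sympy

    MIN_NUM_ELEMENTS, MAX_NUM_ELEMENTS = 2**4, 2**16  # assumed profiled bounds

    m, n = sympy.symbols("BLOCK_SIZE_M BLOCK_SIZE_N", positive=True)
    inequalities = (m * n >= MIN_NUM_ELEMENTS) & (m * n <= MAX_NUM_ELEMENTS)

    def powers_of_two(lower=2**5, upper=2**10):
        # value & (value - 1) == 0 is the same power-of-two test used above
        return [v for v in range(lower, upper + 1) if v & (v - 1) == 0]

    configs = [
        {m: vm, n: vn}
        for vm, vn in itertools.product(powers_of_two(), powers_of_two())
        if sympy.logic.simplify_logic(inequalities.subs({m: vm, n: vn}))
    ]

    print(len(configs), "surviving block-size configurations")
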
ninetoothed/jit.py CHANGED
@@ -4,12 +4,25 @@ import sys
 from ninetoothed.generation import CodeGenerator
 
 
-def jit(func=None, *, caller="torch", kernel_name=None, _prettify=False):
+def jit(
+    func=None,
+    *,
+    caller="torch",
+    kernel_name=None,
+    num_warps=None,
+    num_stages=None,
+    max_num_configs=None,
+    _prettify=False,
+):
     """A decorator for generating compute kernels.
 
     :param func: The function to be compiled.
     :param caller: Who will call the compute kernel.
     :param kernel_name: The name for the generated kernel.
+    :param num_warps: The number of warps to use.
+    :param num_stages: The number of pipeline stages.
+    :param max_num_configs: The maximum number of auto-tuning
+        configurations to use.
     :param _prettify: Whether to prettify the generated code.
     :return: A handle to the compute kernel.
 
@@ -20,7 +33,15 @@ def jit(func=None, *, caller="torch", kernel_name=None, _prettify=False):
     """
 
     def wrapper(func):
-        return JIT(func, caller=caller, kernel_name=kernel_name, _prettify=_prettify)()
+        return JIT(
+            func,
+            caller=caller,
+            kernel_name=kernel_name,
+            num_warps=num_warps,
+            num_stages=num_stages,
+            max_num_configs=max_num_configs,
+            _prettify=_prettify,
+        )()
 
     if func is None:
         return wrapper
@@ -29,7 +50,16 @@ def jit(func=None, *, caller="torch", kernel_name=None, _prettify=False):
 
 
 class JIT:
-    def __init__(self, func, caller, kernel_name, _prettify=False):
+    def __init__(
+        self,
+        func,
+        caller,
+        kernel_name,
+        num_warps,
+        num_stages,
+        max_num_configs,
+        _prettify=False,
+    ):
         self.func = func
 
         self._caller = caller
@@ -39,12 +69,24 @@ class JIT:
         else:
             self._kernel_name = func.__name__
 
+        self._num_warps = num_warps
+
+        self._num_stages = num_stages
+
+        self._max_num_configs = max_num_configs
+
         self._prettify = _prettify
 
     def __call__(self):
         code_generator = CodeGenerator()
         source_file = code_generator(
-            self.func, self._caller, self._kernel_name, self._prettify
+            self.func,
+            self._caller,
+            self._kernel_name,
+            self._num_warps,
+            self._num_stages,
+            self._max_num_configs,
+            self._prettify,
         )
         module = type(self)._import_from_path(source_file, source_file)
         module_vars = vars(module)
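
Per _generate_autotune above, each new keyword accepts either a scalar (pinning the value) or an iterable (one auto-tuning candidate per value), and max_num_configs randomly samples the generated config list down. A usage sketch, modeled on the project's vector-addition example; the tiling arrangement shown is assumed from the README rather than taken from this diff:

    import ninetoothed
    from ninetoothed import Symbol, Tensor

    BLOCK_SIZE = Symbol("BLOCK_SIZE", meta=True)

    @ninetoothed.jit(
        num_warps=(4, 8),    # iterable: both values become auto-tuning candidates
        num_stages=3,        # scalar: pinned
        max_num_configs=16,  # randomly sample at most 16 generated configs
    )
    def add_kernel(
        x: Tensor(1).tile((BLOCK_SIZE,)),
        y: Tensor(1).tile((BLOCK_SIZE,)),
        z: Tensor(1).tile((BLOCK_SIZE,)),
    ):
        z = x + y
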
ninetoothed/make.py CHANGED
@@ -2,6 +2,7 @@ import inspect
 
 from ninetoothed.aot import aot
 from ninetoothed.jit import jit
+from ninetoothed.utils import calculate_default_configs
 
 
 def make(
@@ -11,8 +12,9 @@ def make(
     caller="torch",
     kernel_name=None,
     output_dir=None,
-    num_warps=4,
-    num_stages=3,
+    num_warps=None,
+    num_stages=None,
+    max_num_configs=None,
 ):
     """Integrate the arrangement and the application of the tensors.
 
@@ -24,16 +26,33 @@ def make(
     :param output_dir: The directory to store the generated files.
     :param num_warps: The number of warps to use.
     :param num_stages: The number of pipeline stages.
+    :param max_num_configs: The maximum number of auto-tuning
+        configurations to use.
     :return: A handle to the compute kernel.
     """
 
+    default_num_warps, default_num_stages = calculate_default_configs()
+
+    if num_warps is None:
+        num_warps = default_num_warps
+
+    if num_stages is None:
+        num_stages = default_num_stages
+
     params = inspect.signature(application).parameters
     types = arrangement(*tensors)
     annotations = {param: type for param, type in zip(params, types)}
     application.__annotations__ = annotations
 
     if caller == "torch":
-        return jit(application, caller=caller, kernel_name=kernel_name)
+        return jit(
+            application,
+            caller=caller,
+            kernel_name=kernel_name,
+            num_warps=num_warps,
+            num_stages=num_stages,
+            max_num_configs=max_num_configs,
+        )
 
     return aot(
         application,
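
make() now resolves its num_warps/num_stages defaults through calculate_default_configs() (see ninetoothed/utils.py below) and forwards all three tuning knobs to jit() for the torch caller. A sketch, with the arrangement/application pair assumed in the style of the project README rather than taken from this diff:

    import ninetoothed
    from ninetoothed import Tensor, block_size

    BLOCK_SIZE = block_size()

    def arrangement(x, y, z):
        return (
            x.tile((BLOCK_SIZE,)),
            y.tile((BLOCK_SIZE,)),
            z.tile((BLOCK_SIZE,)),
        )

    def application(x, y, z):
        z = x + y

    # num_warps/num_stages fall back to device-derived defaults when None.
    add_kernel = ninetoothed.make(
        arrangement,
        application,
        (Tensor(1), Tensor(1), Tensor(1)),
        max_num_configs=32,
    )
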
ninetoothed/symbol.py CHANGED
@@ -12,9 +12,20 @@ class Symbol:
     :param expr: The expression used to construct the symbol.
     :param constexpr: Whether the symbol is a constexpr.
     :param mata: Whether the symbol is a meta.
+    :param lower_bound: The minimum value for the symbol's range.
+    :param upper_bound: The maximum value for the symbol's range.
+    :param power_of_two: Whether the value should be a power of two.
     """
 
-    def __init__(self, expr, constexpr=None, meta=None):
+    def __init__(
+        self,
+        expr,
+        constexpr=None,
+        meta=None,
+        lower_bound=None,
+        upper_bound=None,
+        power_of_two=None,
+    ):
         if isinstance(expr, type(self)):
             self._node = expr._node
             return
@@ -43,6 +54,40 @@ class Symbol:
         if constexpr:
             self._node.id = naming.make_constexpr(self._node.id)
 
+        self._node.symbol = self
+
+        DEFAULT_LOWER_BOUND_FOR_META_SYMBOLS = 2**5
+        DEFAULT_UPPER_BOUND_FOR_META_SYMBOLS = 2**10
+        DEFAULT_POWER_OF_TWO_FOR_META_SYMBOLS = True
+
+        DEFAULT_LOWER_BOUND_FOR_NON_META_CONSTEXPR_SYMBOLS = 1
+        DEFAULT_UPPER_BOUND_FOR_NON_META_CONSTEXPR_SYMBOLS = 2**20
+        DEFAULT_POWER_OF_TWO_FOR_NON_META_CONSTEXPR_SYMBOLS = False
+
+        if lower_bound is not None:
+            self.lower_bound = lower_bound
+        else:
+            if meta:
+                self.lower_bound = DEFAULT_LOWER_BOUND_FOR_META_SYMBOLS
+            elif constexpr:
+                self.lower_bound = DEFAULT_LOWER_BOUND_FOR_NON_META_CONSTEXPR_SYMBOLS
+
+        if upper_bound is not None:
+            self.upper_bound = upper_bound
+        else:
+            if meta:
+                self.upper_bound = DEFAULT_UPPER_BOUND_FOR_META_SYMBOLS
+            elif constexpr:
+                self.upper_bound = DEFAULT_UPPER_BOUND_FOR_NON_META_CONSTEXPR_SYMBOLS
+
+        if power_of_two is not None:
+            self.power_of_two = power_of_two
+        else:
+            if meta:
+                self.power_of_two = DEFAULT_POWER_OF_TWO_FOR_META_SYMBOLS
+            elif constexpr:
+                self.power_of_two = DEFAULT_POWER_OF_TWO_FOR_NON_META_CONSTEXPR_SYMBOLS
+
     def __eq__(self, other):
         if isinstance(self._node, ast.Constant):
             if isinstance(other, Symbol) and isinstance(other._node, ast.Constant):
@@ -155,7 +200,7 @@ class Symbol:
             def visit_Name(self, node):
                 self.generic_visit(node)
 
-                self.names.add(node.id)
+                self.names.add(node.symbol)
 
         name_collector = NameCollector()
 
@@ -179,6 +224,24 @@ class Symbol:
         return isinstance(object, Symbol) and isinstance(object.node, ast.Name)
 
 
+def block_size(lower_bound=None, upper_bound=None):
+    """Create a block size symbol that serves as a meta-parameter.
+
+    :param lower_bound: The lower bound for the block size's range.
+    :param upper_bound: The upper bound for the block size's range.
+    :return: A block size symbol that serves as a meta-parameter.
+    """
+
+    name = naming.auto_generate(f"BLOCK_SIZE_{block_size._num_block_sizes}")
+
+    block_size._num_block_sizes += 1
+
+    return Symbol(name, meta=True, lower_bound=lower_bound, upper_bound=upper_bound)
+
+
+block_size._num_block_sizes = 0
+
+
 class _FindAndReplacer(ast.NodeTransformer):
     def __init__(self, targets, replacement):
         self._targets_unparsed = tuple(
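
Bounds now default per symbol kind: meta-symbols range over 2**5..2**10 restricted to powers of two, non-meta constexpr symbols over 1..2**20 with no power-of-two restriction, and block_size() mints auto-named meta-symbols with optional overrides. A small sketch of the resulting behavior (the printed values follow the defaults above):

    from ninetoothed import Symbol, block_size

    BLOCK_M = block_size()  # fresh auto-named meta-symbol, default bounds
    BLOCK_N = block_size(lower_bound=64, upper_bound=256)  # narrowed range

    size = Symbol("size", constexpr=True)
    print(size.lower_bound, size.upper_bound, size.power_of_two)  # 1 1048576 False
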
ninetoothed/tensor.py CHANGED
@@ -14,7 +14,7 @@ class Tensor:
     :param dtype: The element type of the tensor.
     :param strides: The strides of the tensor.
     :param other: The values for out-of-bounds positions.
-    :param constexpr_shape: Whether the sizes are constexpr.
+    :param shape_options: The options for configuring shape symbols.
     :param name: The name of the tensor.
     :param source: For internal use only.
     :param source_dims: For internal use only.
@@ -31,7 +31,7 @@ class Tensor:
         dtype=None,
         strides=None,
         other=None,
-        constexpr_shape=None,
+        shape_options=None,
         name=None,
         source=None,
         source_dims=None,
@@ -48,9 +48,20 @@ class Tensor:
             self.name = naming.auto_generate(f"tensor_{type(self).num_instances}")
 
         if ndim is not None:
+            if shape_options is None:
+                shape_options = tuple({} for _ in range(ndim))
+
+            if isinstance(shape_options, dict):
+                shape_options = tuple(shape_options for _ in range(ndim))
+
+            shape_options = tuple(
+                size_options if size_options is not None else {}
+                for size_options in shape_options
+            )
+
             self.shape = (
-                Symbol(self.size_string(i), constexpr=constexpr_shape)
-                for i in range(ndim)
+                Symbol(self.size_string(i), **size_options)
+                for i, size_options in zip(range(ndim), shape_options)
             )
             self.strides = (Symbol(self.stride_string(i)) for i in range(ndim))
         else:
@@ -146,7 +157,7 @@ class Tensor:
             )
             outer_shape.append(new_size)
 
-            new_stride = self_stride * stride // spacing
+            new_stride = self_stride * stride
             outer_strides.append(new_stride)
 
             inner_shape.append(tile_size)
@@ -364,10 +375,10 @@ class Tensor:
 
     def names(self):
         if self.ndim == 0:
-            return {self.source.name}
+            return {Symbol(self.source.name)}
 
         return (
-            {self.source.pointer_string()}
+            {Symbol(self.source.pointer_string())}
             | {
                 name
                 for value in itertools.chain(self.shape, self.strides)
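
The constexpr_shape flag is generalized into shape_options: a single dict applies to every dimension, a sequence supplies per-dimension options (None entries fall back to defaults), and each dict is forwarded as keyword arguments to Symbol. A sketch of the accepted forms:

    from ninetoothed import Tensor

    # One dict for all dimensions (replaces constexpr_shape=True):
    a = Tensor(2, shape_options={"constexpr": True})

    # Per-dimension options; None means default options for that dimension:
    b = Tensor(
        2,
        shape_options=(
            {"constexpr": True, "upper_bound": 2**14},
            None,
        ),
    )

    # Omitted entirely: every size symbol gets default options.
    c = Tensor(2)
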
ninetoothed/utils.py ADDED
@@ -0,0 +1,12 @@
+import triton
+
+
+def calculate_default_configs():
+    device = triton.runtime.driver.active.get_current_device()
+    properties = triton.runtime.driver.active.utils.get_device_properties(device)
+    max_shared_mem = properties["max_shared_mem"]
+
+    num_warps = 8
+    num_stages = max_shared_mem // 2**15
+
+    return num_warps, num_stages
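
This helper extracts the defaults that _generate_autotune hard-coded in 0.15.0: 8 warps, and one pipeline stage per 32 KiB of shared memory (for example, a device reporting 96 KiB yields 98304 // 32768 = 3 stages). It queries the active Triton driver, so it needs a working GPU runtime:

    from ninetoothed.utils import calculate_default_configs

    num_warps, num_stages = calculate_default_configs()  # e.g. (8, 3) at 96 KiB
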
ninetoothed/visualization.py CHANGED
@@ -118,10 +118,16 @@ def _visualize_unit_square(ax, x, y, color):
 
 
 def _visualize_rect(ax, width, height, x, y, color):
-    pos_x, pos_y = zip(*_verts_of_rect(width, height, x, y))
-
-    ax.fill(pos_x, pos_y, color)
-    ax.plot(pos_x + (pos_x[0],), pos_y + (pos_y[0],), "k")
+    ax.add_patch(
+        plt.Rectangle(
+            (x, y),
+            width,
+            height,
+            edgecolor="k",
+            facecolor=color,
+            linewidth=plt.rcParams["lines.linewidth"],
+        )
+    )
 
 
 def _verts_of_rect(width, height, x, y):
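
The hand-rolled fill-plus-outline pair becomes a single Rectangle patch, which keeps face and edge in one artist. A standalone matplotlib sketch of the new call (plt.Rectangle is an alias of matplotlib.patches.Rectangle); note that patches do not trigger autoscaling, hence the explicit axis limits:

    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    ax.add_patch(
        plt.Rectangle(
            (0, 0),  # lower-left corner
            2,       # width
            1,       # height
            edgecolor="k",
            facecolor="tab:blue",
            linewidth=plt.rcParams["lines.linewidth"],
        )
    )
    ax.set_xlim(-0.5, 2.5)
    ax.set_ylim(-0.5, 1.5)
    plt.show()
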
ninetoothed-0.15.0.dist-info/METADATA → ninetoothed-0.16.0.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: ninetoothed
-Version: 0.15.0
+Version: 0.16.0
 Summary: A domain-specific language based on Triton but providing higher-level abstraction.
 Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
 Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
@@ -10,6 +10,7 @@ Classifier: License :: OSI Approved :: Apache Software License
 Classifier: Operating System :: OS Independent
 Classifier: Programming Language :: Python :: 3
 Requires-Python: >=3.10
+Requires-Dist: sympy>=1.13.0
 Requires-Dist: triton>=3.0.0
 Provides-Extra: all
 Requires-Dist: matplotlib>=3.9.0; extra == 'all'
ninetoothed-0.16.0.dist-info/RECORD ADDED
@@ -0,0 +1,18 @@
+ninetoothed/__init__.py,sha256=F2bxRNhzcGdtADA8RehTuf-QK0xnxno8kxvr6H2L5Tg,552
+ninetoothed/aot.py,sha256=8ZCLtnsign14YvY7SXX5ASidhuUAhPwppTXUJNkQup4,6243
+ninetoothed/cudaifier.py,sha256=5ylMr1q0B9NwbeXkpCu3o2nMGpDfh65nAQ0Az_qMQuI,877
+ninetoothed/dtype.py,sha256=-0iBleay5gYA4wtT3l17QjCesr7g26M6CSfhNJdI3k4,165
+ninetoothed/generation.py,sha256=VIqSyZT4yHxY_a2QPmWW6jjALv3e1mohDqdRQBRYsAo,36462
+ninetoothed/jit.py,sha256=CpeSkO_zUe9DwtTJ2K2H7Bwpx-FvIHfrgzOcEosfpek,2946
+ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
+ninetoothed/make.py,sha256=fQKuRJL7HC2iGTAN323mlIWXz9Z3jotIoN68ur29Qlw,1834
+ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
+ninetoothed/symbol.py,sha256=lJo3NL2-T7tKbKjb6MCRLMemN94mqS3bIiG943P0Mbo,7454
+ninetoothed/tensor.py,sha256=gQEzHTcXqZVBFLc2YRfXTKxjxPWMxWN7fNl2BCfJwMs,14782
+ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
+ninetoothed/utils.py,sha256=mtRXABBVPnlgd2n1REh9oB3s_5bUsKhd3iwu3oJ5DSQ,338
+ninetoothed/visualization.py,sha256=zlMH-0WplaboePGzcbpcj4UovpX0k2r4SysSPsNS4r4,3674
+ninetoothed-0.16.0.dist-info/METADATA,sha256=nkq3iImebtmcEs-bZq2zfF2_QxrZD9IWky1S86OnUMA,7340
+ninetoothed-0.16.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ninetoothed-0.16.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ninetoothed-0.16.0.dist-info/RECORD,,
ninetoothed-0.15.0.dist-info/RECORD DELETED
@@ -1,17 +0,0 @@
-ninetoothed/__init__.py,sha256=zGaZiUzwJZ2jfwLxp7lT8ll_V5ngP5QYrfVbapftbCY,522
-ninetoothed/aot.py,sha256=1hzl4-6MqscB4tDMqmLCOTlyzsYkbY20EnmDgHO8hU4,6137
-ninetoothed/cudaifier.py,sha256=5ylMr1q0B9NwbeXkpCu3o2nMGpDfh65nAQ0Az_qMQuI,877
-ninetoothed/dtype.py,sha256=-0iBleay5gYA4wtT3l17QjCesr7g26M6CSfhNJdI3k4,165
-ninetoothed/generation.py,sha256=QHrK7DOuJo5wEV-5HAqu_e-suuD4TPPNmCzrUMRYF2w,30940
-ninetoothed/jit.py,sha256=0MFbFIODtw-bxuOC7WByxiVtQMeyvZkoDxvfAZ9rIFQ,2120
-ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
-ninetoothed/make.py,sha256=wRr3JwGt5E2OCquq_nzBZljdW-AJPOqH49cM08gwl4A,1287
-ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
-ninetoothed/symbol.py,sha256=UpGmx_jvaDtowADnp1DwYC3fvBXSiaMiYpU-ewkVo50,5261
-ninetoothed/tensor.py,sha256=W1XY8_vaYmszX4lIWuas-ZKGbbdEZU7Z5h1A4FBXDXg,14358
-ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
-ninetoothed/visualization.py,sha256=IZ7iTT4dl5_JFbO-WfSWPFWpgkyPr4nylwhSZVy8gss,3601
-ninetoothed-0.15.0.dist-info/METADATA,sha256=UmND-TBDf7vrdii9dhiOZmiZBzrcG8-xEniThTSikqM,7311
-ninetoothed-0.15.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-ninetoothed-0.15.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-ninetoothed-0.15.0.dist-info/RECORD,,