ninetoothed 0.15.1__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ninetoothed/__init__.py CHANGED
@@ -13,12 +13,13 @@ from ninetoothed.dtype import (
 )
 from ninetoothed.jit import jit
 from ninetoothed.make import make
-from ninetoothed.symbol import Symbol
+from ninetoothed.symbol import Symbol, block_size
 from ninetoothed.tensor import Tensor

 __all__ = [
     "Symbol",
     "Tensor",
+    "block_size",
     "float16",
     "float32",
     "float64",
ninetoothed/aot.py CHANGED
@@ -31,12 +31,18 @@ def _aot(func, caller, kernel_name, num_warps, num_stages):

     _HEADER_PATH.parent.mkdir(exist_ok=True)

-    if not _HEADER_PATH.exists():
+    if not _HEADER_PATH.exists() or _HEADER_PATH.read_text() != _HEADER_CONTENT:
         _HEADER_PATH.write_text(_HEADER_CONTENT)

     code_generator = CodeGenerator()
     source_file = code_generator(
-        func, caller=caller, kernel_name=kernel_name, prettify=False
+        func,
+        caller=caller,
+        kernel_name=kernel_name,
+        num_warps=num_warps,
+        num_stages=num_stages,
+        max_num_configs=None,
+        prettify=False,
     )

     tensors = code_generator.tensors
@@ -85,20 +91,29 @@ def _aot(func, caller, kernel_name, num_warps, num_stages):

     c_header_file_name = f"{kernel_name}.{signature_hash}.h"
     c_header_file = output_contents[c_header_file_name]
-    c_header_file = f"{c_header_file}\n{unparser.header};\n"
+    c_header_file = f'{c_header_file}\n#ifdef __cplusplus\nextern "C" {unparser.header};\n#else\n{unparser.header};\n#endif\n'
     c_header_file = c_header_file.replace("<stdint.h>", f'"{_HEADER_PATH}"')
     output_contents[c_header_file_name] = c_header_file

     return output_contents


-_HEADER_CONTENT = """#include <stdint.h>
+_HEADER_CONTENT = """#ifndef NINETOOTHED_H
+#define NINETOOTHED_H
+
+#include <stdint.h>

 typedef struct {
-    uintptr_t data;
+    void *data;
     uint64_t *shape;
     int64_t *strides;
 } NineToothedTensor;
+
+typedef void *NineToothedStream;
+
+typedef int NineToothedResult;
+
+#endif // NINETOOTHED_H
 """

 _HEADER_PATH = CACHE_DIR / "ninetoothed.h"
@@ -129,9 +144,9 @@ class _Unparser:
         return f"return {self._generic_unparse(call)};"

     def _unparse_FunctionDef(self, node):
-        params = ["CUstream stream"]
+        params = ["NineToothedStream stream"]
         params += [f"NineToothedTensor {arg.arg}" for arg in node.args.args]
-        header = f"CUresult {node.name}({', '.join(params)})"
+        header = f"NineToothedResult {node.name}({', '.join(params)})"

         self.header = header

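Note: the generated header is now guarded against re-inclusion, wraps the kernel declaration in extern "C" for C++ consumers, and swaps the CUDA-specific CUstream/CUresult for the opaque NineToothedStream and NineToothedResult aliases, decoupling the declared ABI from any one backend. A hypothetical ctypes sketch of binding a compiled kernel against that ABI; the library path, kernel name, and tensor count are assumptions:

    import ctypes

    class NineToothedTensor(ctypes.Structure):
        # Mirrors the struct in ninetoothed.h; `data` is now a void *
        # instead of a uintptr_t.
        _fields_ = [
            ("data", ctypes.c_void_p),
            ("shape", ctypes.POINTER(ctypes.c_uint64)),
            ("strides", ctypes.POINTER(ctypes.c_int64)),
        ]

    def bind_kernel(library_path, kernel_name, num_tensors):
        lib = ctypes.CDLL(library_path)
        kernel = getattr(lib, kernel_name)
        kernel.restype = ctypes.c_int  # NineToothedResult
        # The first parameter is a NineToothedStream, i.e. an opaque void *.
        kernel.argtypes = [ctypes.c_void_p] + [NineToothedTensor] * num_tensors
        return kernel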
ninetoothed/generation.py CHANGED
@@ -5,11 +5,21 @@ import functools
 import hashlib
 import inspect
 import itertools
+import json
 import math
+import os
 import pathlib
+import random
+import shutil
 import subprocess
+import tempfile
+import time
+import uuid

+import sympy
 import triton
+import triton.language as tl
+from triton.language.extra import libdevice

 import ninetoothed.naming as naming
 from ninetoothed.cudaifier import Cudaifier
@@ -19,19 +29,47 @@ from ninetoothed.tensor import Tensor
 from ninetoothed.torchifier import Torchifier

 CACHE_DIR = pathlib.Path.home() / ".ninetoothed"
+CACHE_DIR.mkdir(exist_ok=True)


 class CodeGenerator(ast.NodeTransformer):
     def __init__(self):
         super().__init__()

-        self._POWER_OF_TWOS = tuple(2**n for n in range(5, 11))
+        cache_file = CACHE_DIR / "code_generator_cache.json"

-        self._MIN_PRODUCT = 2**10
+        log2_min_num_elements = 4

-        self._MAX_PRODUCT = 2**20
+        if cache_file.exists():
+            with open(cache_file) as f:
+                cache = json.load(f)

-    def __call__(self, func, caller, kernel_name, prettify):
+            log2_max_num_elements = cache["log2_max_num_elements"]
+        else:
+            log2_max_num_elements = _determine_log2_max_num_elements_per_block(
+                log2_min_num_elements
+            )
+
+            cache = {"log2_max_num_elements": log2_max_num_elements}
+
+            with open(cache_file, "w") as f:
+                json.dump(cache, f, indent=4)
+                f.write("\n")
+
+        self._min_num_elements = 2**log2_min_num_elements
+
+        self._max_num_elements = 2**log2_max_num_elements
+
+    def __call__(
+        self,
+        func,
+        caller,
+        kernel_name,
+        num_warps,
+        num_stages,
+        max_num_configs,
+        prettify,
+    ):
         def _get_tree(func):
             module = ast.parse(inspect.getsource(inspect.getmodule(func)))

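Note: CodeGenerator now probes the device once for the largest practical block size (via _determine_log2_max_num_elements_per_block, added at the end of this file) and memoizes the result, so subsequent runs skip the probe. A sketch for inspecting the memoized value; the exponent is machine-specific:

    import json
    import pathlib

    cache_file = pathlib.Path.home() / ".ninetoothed" / "code_generator_cache.json"
    # Written after the first generation, e.g. {"log2_max_num_elements": 20}.
    print(json.loads(cache_file.read_text()))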
@@ -63,6 +101,12 @@ class CodeGenerator(ast.NodeTransformer):

         self._caller = caller

+        self._num_wraps = num_warps
+
+        self._num_stages = num_stages
+
+        self._max_num_configs = max_num_configs
+
         self._context = inspect.get_annotations(func)

         self._args = list(self._context.values())
@@ -94,9 +138,7 @@ class CodeGenerator(ast.NodeTransformer):
         )

         digest = hashlib.sha256(source.encode("utf-8")).hexdigest()
-        cache_dir = CACHE_DIR
-        cache_dir.mkdir(exist_ok=True)
-        cache_file = cache_dir / f"{digest}.py"
+        cache_file = CACHE_DIR / f"{digest}.py"

         if not cache_file.exists():
             with open(cache_file, "w", encoding="utf-8") as f:
@@ -111,11 +153,15 @@
     def visit_Module(self, node):
         self.generic_visit(node)

-        func_with_auto_tuning = f"{Symbol(self._autotune)}({self._func_def.name})"
+        if self._autotune is not None:
+            func_with_auto_tuning = f"{Symbol(self._autotune)}({self._func_def.name})"
+
+            node.body.append(
+                ast.parse(
+                    f"{self._func_name_with_auto_tuning} = {func_with_auto_tuning}"
+                )
+            )

-        node.body.append(
-            ast.parse(f"{self._func_name_with_auto_tuning} = {func_with_auto_tuning}")
-        )
         node.body.append(self._launch)

         return node
@@ -137,8 +183,13 @@
     def visit_arguments(self, node):
         self.generic_visit(node)

-        names_of_args = [arg.names() - {"ninetoothed"} for arg in self._args]
-        names = functools.reduce(lambda x, y: x | y, names_of_args)
+        symbols = {
+            name.node.id: name
+            for arg in self._args
+            for name in arg.names()
+            if name != "ninetoothed"
+        }
+        names = symbols.keys()
         meta_names = {name for name in names if naming.is_meta(name)}
         non_meta_names = {name for name in names if name not in meta_names}
         non_meta_names |= {
@@ -147,6 +198,8 @@
             if naming.is_constexpr(name)
         }

+        self._symbols = symbols
+
         non_meta_names = sorted(non_meta_names)
         meta_names = sorted(meta_names)

@@ -161,12 +214,53 @@
         ]

         self._autotune = self._generate_autotune(non_meta_names, meta_names)
+
+        if self._autotune is not None:
+            self._func_name = self._func_name_with_auto_tuning
+        else:
+            self._func_name = self._func_def.name
+
         self._func_def.decorator_list = [Symbol("triton.jit").node]

         self._launch = self._generate_launch(non_meta_names, meta_names)

         return node

+    def visit_Call(self, node):
+        def _offsets(tensor, dim=None):
+            if dim is None:
+                return tensor._last_generated_overall_offsets.node
+
+            offsets = tensor._last_generated_offsets
+
+            if dim < 0:
+                dim += tensor.source.ndim
+
+            return sum(
+                offsets[dim][target_dim] for target_dim in range(tensor.target.ndim)
+            ).node
+
+        func = node.func
+        args = node.args
+
+        if isinstance(func, ast.Attribute):
+            if func.attr == "offsets":
+                value = func.value
+
+                if self._in_context(value):
+                    tensor = self._context[value.id]
+                elif isinstance(value, ast.Subscript) and self._in_context(value.value):
+                    tensor = self._context[value.value.id]
+
+                self.visit(value)
+
+                # TODO: Add error handling.
+                return _offsets(tensor, ast.literal_eval(args[0]) if args else None)
+
+        self.generic_visit(node)
+
+        return node
+
     def visit_Subscript(self, node):
         if self._in_context(node.value) and isinstance(node.ctx, ast.Load):
             value = self._context[node.value.id]
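Note: the new visit_Call hook rewrites tensor.offsets() calls inside an application function into the offset expressions recorded during pointer generation (see the _last_generated_offsets bookkeeping further down this diff): no argument yields the overall offsets, while an integer selects one source dimension; the argument must be a literal, since it is read with ast.literal_eval. A hypothetical sketch of user code exercising it; the arrangement and semantics of the stored value are assumptions:

    import ninetoothed
    from ninetoothed import Tensor, block_size

    BLOCK_SIZE = block_size()

    @ninetoothed.jit
    def offsets_kernel(x: Tensor(1).tile((BLOCK_SIZE,))):
        # Rewritten at compile time into the generated offsets of x
        # along source dimension 0.
        x = x.offsets(0)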
@@ -184,13 +278,24 @@
         return node

     def visit_Attribute(self, node):
-        if self._in_context(node.value):
-            value = self._context[node.value.id]
+        value = node.value

-            if isinstance(value, Tensor):
-                inner = value.dtype
+        if isinstance(value, ast.Attribute):
+            value = self.visit_Attribute(value)
+
+        if self._in_context(value):
+            value = self._context[value.id].dtype
+
+        if isinstance(value, Tensor):
+            attr = getattr(value, node.attr)

-                return Symbol(getattr(inner, node.attr)).node
+            if node.attr == "dtype" and attr is None:
+                return Symbol(f"{value.source.pointer_string()}.type.element_ty").node
+
+            if isinstance(attr, Tensor):
+                return attr
+
+            return Symbol(attr).node

         self.generic_visit(node)

@@ -244,12 +349,69 @@
         return isinstance(node, ast.Name) and node.id in self._context

     def _generate_autotune(self, params, meta):
-        device = triton.runtime.driver.active.get_current_device()
-        properties = triton.runtime.driver.active.utils.get_device_properties(device)
-        max_shared_mem = properties["max_shared_mem"]
+        inequalities = True
+
+        for arg in self._args:
+            if arg.ndim == 0:
+                continue
+
+            num_elements = sympy.simplify(str(math.prod(arg.innermost().shape)))
+
+            inequalities &= num_elements <= self._max_num_elements
+            inequalities &= num_elements >= self._min_num_elements
+
+        values_of_meta_params = []
+
+        for param in meta:
+            symbol = self._symbols[param]
+
+            values = range(symbol.lower_bound, symbol.upper_bound + 1)
+
+            if symbol.power_of_two:
+                values = tuple(value for value in values if value & (value - 1) == 0)
+            else:
+                values = tuple(values)
+
+            values_of_meta_params.append(values)
+
+        max_values_of_non_meta_params = {}
+
+        for free_symbol in inequalities.free_symbols:
+            symbol_str = str(free_symbol)
+
+            if symbol_str in meta:
+                continue
+
+            symbol = self._symbols[symbol_str]
+
+            max_values_of_non_meta_params[symbol_str] = symbol.upper_bound

-        num_warps = 8
-        num_stages = max_shared_mem // 2**15
+        block_size_configs = []
+
+        for values in itertools.product(*values_of_meta_params):
+            config = {param: value for param, value in zip(meta, values)}
+
+            if sympy.logic.simplify_logic(
+                inequalities.subs(config | max_values_of_non_meta_params)
+            ):
+                block_size_configs.append(config)
+
+        if isinstance(self._num_wraps, collections.abc.Iterable):
+            num_warps_configs = self._num_wraps
+        else:
+            num_warps_configs = (self._num_wraps,)
+
+        if isinstance(self._num_stages, collections.abc.Iterable):
+            num_stages_configs = self._num_stages
+        else:
+            num_stages_configs = (self._num_stages,)
+
+        compiler_configs = tuple(
+            {"num_warps": num_warps, "num_stages": num_stages}
+            for num_warps, num_stages in itertools.product(
+                num_warps_configs, num_stages_configs
+            )
+        )

         configs = [
             ast.Call(
@@ -260,19 +422,38 @@
                 ),
                 args=[
                     ast.Dict(
-                        keys=[ast.Constant(value=param) for param in meta],
-                        values=[ast.Constant(value=value) for value in values],
+                        keys=[
+                            ast.Constant(value=param)
+                            for param in block_size_config.keys()
+                        ],
+                        values=[
+                            ast.Constant(value=value)
+                            for value in block_size_config.values()
+                        ],
                     )
                 ],
                 keywords=[
-                    ast.keyword(arg="num_warps", value=ast.Constant(value=num_warps)),
-                    ast.keyword(arg="num_stages", value=ast.Constant(value=num_stages)),
+                    ast.keyword(
+                        arg="num_warps",
+                        value=ast.Constant(value=compiler_config["num_warps"]),
+                    ),
+                    ast.keyword(
+                        arg="num_stages",
+                        value=ast.Constant(value=compiler_config["num_stages"]),
+                    ),
                 ],
             )
-            for values in itertools.product(self._POWER_OF_TWOS, repeat=len(meta))
-            if self._MIN_PRODUCT <= math.prod(values) <= self._MAX_PRODUCT
+            for block_size_config, compiler_config in itertools.product(
+                block_size_configs, compiler_configs
+            )
         ]

+        if self._max_num_configs is not None and len(configs) > self._max_num_configs:
+            configs = random.sample(configs, k=self._max_num_configs)
+
+        if not configs:
+            return None
+
         return ast.Call(
             func=ast.Attribute(
                 value=ast.Name(id="ninetoothed", ctx=ast.Load()),
@@ -358,9 +539,7 @@
                 ast.Expr(
                     ast.Call(
                         func=ast.Subscript(
-                            value=ast.Name(
-                                id=self._func_name_with_auto_tuning, ctx=ast.Load()
-                            ),
+                            value=ast.Name(id=self._func_name, ctx=ast.Load()),
                             slice=self._generate_grid(),
                             ctx=ast.Load(),
                         ),
@@ -428,6 +607,8 @@
         indices = self._complete_indices(tensor, indices)
         offsets = type(self)._generate_offsets(tensor, indices)

+        tensor._last_generated_offsets = offsets
+
         for source_dim in range(tensor.source.ndim):
             for target_dim in range(tensor.target.ndim):
                 if target_dim not in invariant_target_dims:
@@ -452,7 +633,7 @@
                     * tensor.source.strides[source_dim]
                 )

-        pointers = name_for_pointers + sum(
+        overall_offsets = sum(
             offsets[source_dim][target_dim][
                 type(self)._generate_slices(tensor, target_dim)
             ]
@@ -462,6 +643,10 @@
             if target_dim not in invariant_target_dims
             and offsets[source_dim][target_dim] != 0
         )
+
+        tensor._last_generated_overall_offsets = overall_offsets
+
+        pointers = name_for_pointers + overall_offsets
         mask = functools.reduce(
             lambda x, y: x & y,
             (
@@ -848,6 +1033,9 @@ class _Inliner(ast.NodeTransformer):
         if func_def is None:
             return None, []

+        if inspect.getmodule(func) is libdevice:
+            return None, []
+
         collector = _ImportCollector()
         collector.visit(ast.parse(inspect.getsource(inspect.getmodule(func))))
         self.imports.extend(collector.imports)
@@ -994,3 +1182,82 @@ class _FunctionDefFinder(ast.NodeVisitor):
             self.result = node

         self.generic_visit(node)
+
+
+def _determine_log2_max_num_elements_per_block(
+    min_exponent, max_exponent=30, num_iterations=3
+):
+    _profile_pseudo_add_kernel(1)
+
+    for n in range(min_exponent, max_exponent + 1):
+        elapsed_time = 0
+
+        for _ in range(num_iterations):
+            elapsed_time += _profile_pseudo_add_kernel(2**n)
+
+        average_elapsed_time = elapsed_time / num_iterations
+
+        if average_elapsed_time >= 1:
+            return n - 1
+
+
+def _profile_pseudo_add_kernel(block_size):
+    cache_dir = triton.runtime.cache.default_cache_dir()
+    os.makedirs(cache_dir, exist_ok=True)
+
+    with tempfile.TemporaryDirectory() as backup_dir:
+        backup_path = os.path.join(backup_dir, str(uuid.uuid4()))
+
+        if os.path.exists(backup_path):
+            shutil.rmtree(backup_path)
+
+        shutil.move(cache_dir, backup_path)
+
+        try:
+            start_time = time.time()
+
+            _run_pseudo_add_kernel(block_size)
+
+            end_time = time.time()
+
+            elapsed_time = end_time - start_time
+        finally:
+            if os.path.exists(cache_dir):
+                shutil.rmtree(cache_dir)
+
+            shutil.move(backup_path, cache_dir)
+
+    return elapsed_time
+
+
+def _run_pseudo_add_kernel(block_size):
+    @triton.jit
+    def kernel(a_ptr, b_ptr, c_ptr, num_elements, BLOCK_SIZE: tl.constexpr):
+        pid = tl.program_id(0)
+
+        offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        mask = offs < num_elements
+
+        a = tl.load(a_ptr + offs, mask=mask)
+        b = tl.load(b_ptr + offs, mask=mask)
+
+        c = a + b
+
+        tl.store(c_ptr + offs, c, mask=mask)
+
+    num_elements = 0
+    shape = (num_elements,)
+    dtype = tl.float32
+
+    a = Tensor(shape=shape, dtype=dtype)
+    b = Tensor(shape=shape, dtype=dtype)
+    c = Tensor(shape=shape, dtype=dtype)
+
+    def data_ptr():
+        return 0
+
+    a.data_ptr = data_ptr
+    b.data_ptr = data_ptr
+    c.data_ptr = data_ptr
+
+    kernel[(1,)](a, b, c, num_elements, block_size)
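Note: the probe times a cold compile-and-launch of a trivial add kernel at increasing block sizes, moving Triton's on-disk cache aside so each compilation really is cold, and stops one step below the first exponent whose average exceeds one second. A sketch of reproducing the measurement by hand; this uses private helpers added above and needs a working Triton device, so treat it as illustrative:

    from ninetoothed.generation import _profile_pseudo_add_kernel

    # Warm up once, then time cold compiles at growing block sizes.
    _profile_pseudo_add_kernel(1)
    for n in (10, 14, 18):
        print(n, _profile_pseudo_add_kernel(2**n))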
ninetoothed/jit.py CHANGED
@@ -4,12 +4,25 @@ import sys
 from ninetoothed.generation import CodeGenerator


-def jit(func=None, *, caller="torch", kernel_name=None, _prettify=False):
+def jit(
+    func=None,
+    *,
+    caller="torch",
+    kernel_name=None,
+    num_warps=None,
+    num_stages=None,
+    max_num_configs=None,
+    _prettify=False,
+):
     """A decorator for generating compute kernels.

     :param func: The function to be compiled.
     :param caller: Who will call the compute kernel.
     :param kernel_name: The name for the generated kernel.
+    :param num_warps: The number of warps to use.
+    :param num_stages: The number of pipeline stages.
+    :param max_num_configs: The maximum number of auto-tuning
+        configurations to use.
     :param _prettify: Whether to prettify the generated code.
     :return: A handle to the compute kernel.

@@ -20,7 +33,15 @@ def jit(func=None, *, caller="torch", kernel_name=None, _prettify=False):
     """

     def wrapper(func):
-        return JIT(func, caller=caller, kernel_name=kernel_name, _prettify=_prettify)()
+        return JIT(
+            func,
+            caller=caller,
+            kernel_name=kernel_name,
+            num_warps=num_warps,
+            num_stages=num_stages,
+            max_num_configs=max_num_configs,
+            _prettify=_prettify,
+        )()

     if func is None:
         return wrapper
@@ -29,7 +50,16 @@ def jit(func=None, *, caller="torch", kernel_name=None, _prettify=False):


 class JIT:
-    def __init__(self, func, caller, kernel_name, _prettify=False):
+    def __init__(
+        self,
+        func,
+        caller,
+        kernel_name,
+        num_warps,
+        num_stages,
+        max_num_configs,
+        _prettify=False,
+    ):
         self.func = func

         self._caller = caller
@@ -39,12 +69,24 @@ class JIT:
         else:
             self._kernel_name = func.__name__

+        self._num_warps = num_warps
+
+        self._num_stages = num_stages
+
+        self._max_num_configs = max_num_configs
+
         self._prettify = _prettify

     def __call__(self):
         code_generator = CodeGenerator()
         source_file = code_generator(
-            self.func, self._caller, self._kernel_name, self._prettify
+            self.func,
+            self._caller,
+            self._kernel_name,
+            self._num_warps,
+            self._num_stages,
+            self._max_num_configs,
+            self._prettify,
         )
         module = type(self)._import_from_path(source_file, source_file)
         module_vars = vars(module)
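Note: with this plumbing, num_warps, num_stages, and max_num_configs flow from the decorator through JIT into CodeGenerator.__call__. Invocation is unchanged; a sketch assuming the add kernel from the earlier example and a CUDA-capable device:

    import torch

    a = torch.randn(1024, device="cuda")
    b = torch.randn(1024, device="cuda")
    c = torch.empty_like(a)

    add(a, b, c)  # auto-tunes over the generated configs on first launch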
ninetoothed/language.py CHANGED
@@ -1,7 +1,11 @@
 import ast

+from triton.language.extra import libdevice
+
 from ninetoothed.symbol import Symbol

+__all__ = ["libdevice"]
+
 LANGUAGE = "ninetoothed.language"


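Note: re-exporting libdevice pairs with the generation.py change that skips inlining for libdevice functions, so intrinsic calls are passed through to Triton. A sketch, assuming the conventional import alias and the tiled annotation style from the earlier examples:

    import ninetoothed
    import ninetoothed.language as ntl
    from ninetoothed import Tensor, block_size

    BLOCK_SIZE = block_size()

    @ninetoothed.jit
    def exp_kernel(
        x: Tensor(1).tile((BLOCK_SIZE,)),
        y: Tensor(1).tile((BLOCK_SIZE,)),
    ):
        # Forwarded to triton.language.extra.libdevice.exp, not inlined.
        y = ntl.libdevice.exp(x)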
ninetoothed/make.py CHANGED
@@ -2,6 +2,7 @@ import inspect

 from ninetoothed.aot import aot
 from ninetoothed.jit import jit
+from ninetoothed.utils import calculate_default_configs


 def make(
@@ -11,8 +12,9 @@ def make(
     caller="torch",
     kernel_name=None,
     output_dir=None,
-    num_warps=4,
-    num_stages=3,
+    num_warps=None,
+    num_stages=None,
+    max_num_configs=None,
 ):
     """Integrate the arrangement and the application of the tensors.

@@ -24,16 +26,33 @@ def make(
     :param output_dir: The directory to store the generated files.
     :param num_warps: The number of warps to use.
     :param num_stages: The number of pipeline stages.
+    :param max_num_configs: The maximum number of auto-tuning
+        configurations to use.
     :return: A handle to the compute kernel.
     """

+    default_num_warps, default_num_stages = calculate_default_configs()
+
+    if num_warps is None:
+        num_warps = default_num_warps
+
+    if num_stages is None:
+        num_stages = default_num_stages
+
     params = inspect.signature(application).parameters
     types = arrangement(*tensors)
     annotations = {param: type for param, type in zip(params, types)}
     application.__annotations__ = annotations

     if caller == "torch":
-        return jit(application, caller=caller, kernel_name=kernel_name)
+        return jit(
+            application,
+            caller=caller,
+            kernel_name=kernel_name,
+            num_warps=num_warps,
+            num_stages=num_stages,
+            max_num_configs=max_num_configs,
+        )

     return aot(
         application,
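Note: make no longer hard-codes num_warps=4 and num_stages=3; both default to device-derived values from the new ninetoothed/utils.py (below), and the JIT path now receives the tuning arguments too. A sketch of the arrangement/application style make expects:

    import ninetoothed
    from ninetoothed import Tensor, block_size

    BLOCK_SIZE = block_size()

    def arrangement(x, y, z):
        return tuple(t.tile((BLOCK_SIZE,)) for t in (x, y, z))

    def application(x, y, z):
        z = x + y

    # num_warps/num_stages fall back to calculate_default_configs().
    add = ninetoothed.make(arrangement, application, tuple(Tensor(1) for _ in range(3)))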
ninetoothed/symbol.py CHANGED
@@ -12,9 +12,20 @@ class Symbol:
     :param expr: The expression used to construct the symbol.
     :param constexpr: Whether the symbol is a constexpr.
     :param mata: Whether the symbol is a meta.
+    :param lower_bound: The minimum value for the symbol's range.
+    :param upper_bound: The maximum value for the symbol's range.
+    :param power_of_two: Whether the value should be a power of two.
     """

-    def __init__(self, expr, constexpr=None, meta=None):
+    def __init__(
+        self,
+        expr,
+        constexpr=None,
+        meta=None,
+        lower_bound=None,
+        upper_bound=None,
+        power_of_two=None,
+    ):
         if isinstance(expr, type(self)):
             self._node = expr._node
             return
@@ -43,6 +54,40 @@ class Symbol:
         if constexpr:
             self._node.id = naming.make_constexpr(self._node.id)

+        self._node.symbol = self
+
+        DEFAULT_LOWER_BOUND_FOR_META_SYMBOLS = 2**5
+        DEFAULT_UPPER_BOUND_FOR_META_SYMBOLS = 2**10
+        DEFAULT_POWER_OF_TWO_FOR_META_SYMBOLS = True
+
+        DEFAULT_LOWER_BOUND_FOR_NON_META_CONSTEXPR_SYMBOLS = 1
+        DEFAULT_UPPER_BOUND_FOR_NON_META_CONSTEXPR_SYMBOLS = 2**20
+        DEFAULT_POWER_OF_TWO_FOR_NON_META_CONSTEXPR_SYMBOLS = False
+
+        if lower_bound is not None:
+            self.lower_bound = lower_bound
+        else:
+            if meta:
+                self.lower_bound = DEFAULT_LOWER_BOUND_FOR_META_SYMBOLS
+            elif constexpr:
+                self.lower_bound = DEFAULT_LOWER_BOUND_FOR_NON_META_CONSTEXPR_SYMBOLS
+
+        if upper_bound is not None:
+            self.upper_bound = upper_bound
+        else:
+            if meta:
+                self.upper_bound = DEFAULT_UPPER_BOUND_FOR_META_SYMBOLS
+            elif constexpr:
+                self.upper_bound = DEFAULT_UPPER_BOUND_FOR_NON_META_CONSTEXPR_SYMBOLS
+
+        if power_of_two is not None:
+            self.power_of_two = power_of_two
+        else:
+            if meta:
+                self.power_of_two = DEFAULT_POWER_OF_TWO_FOR_META_SYMBOLS
+            elif constexpr:
+                self.power_of_two = DEFAULT_POWER_OF_TWO_FOR_NON_META_CONSTEXPR_SYMBOLS
+
     def __eq__(self, other):
         if isinstance(self._node, ast.Constant):
             if isinstance(other, Symbol) and isinstance(other._node, ast.Constant):
@@ -155,7 +200,7 @@ class Symbol:
             def visit_Name(self, node):
                 self.generic_visit(node)

-                self.names.add(node.id)
+                self.names.add(node.symbol)

         name_collector = NameCollector()

@@ -179,6 +224,24 @@ class Symbol:
         return isinstance(object, Symbol) and isinstance(object.node, ast.Name)


+def block_size(lower_bound=None, upper_bound=None):
+    """Create a block size symbol that serves as a meta-parameter.
+
+    :param lower_bound: The lower bound for the block size's range.
+    :param upper_bound: The upper bound for the block size's range.
+    :return: A block size symbol that serves as a meta-parameter.
+    """
+
+    name = naming.auto_generate(f"BLOCK_SIZE_{block_size._num_block_sizes}")
+
+    block_size._num_block_sizes += 1
+
+    return Symbol(name, meta=True, lower_bound=lower_bound, upper_bound=upper_bound)
+
+
+block_size._num_block_sizes = 0
+
+
 class _FindAndReplacer(ast.NodeTransformer):
     def __init__(self, targets, replacement):
         self._targets_unparsed = tuple(
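Note: a Symbol now records the tuning range it may take; meta-parameters default to powers of two in [2**5, 2**10], and non-meta constexprs to [1, 2**20] without the power-of-two restriction. The _node.symbol back-reference lets names() collect full Symbol objects (bounds included) rather than bare strings, which is how _generate_autotune recovers the ranges. Overriding the defaults, with illustrative values:

    from ninetoothed import Symbol, block_size

    # Candidates restricted to 16, 32, 64, and 128.
    BLOCK_M = Symbol("BLOCK_M", meta=True, lower_bound=16, upper_bound=128)

    # The convenience helper from this release, with an auto-generated name.
    BLOCK_N = block_size(lower_bound=16, upper_bound=128)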
ninetoothed/tensor.py CHANGED
@@ -14,7 +14,7 @@ class Tensor:
     :param dtype: The element type of the tensor.
     :param strides: The strides of the tensor.
     :param other: The values for out-of-bounds positions.
-    :param constexpr_shape: Whether the sizes are constexpr.
+    :param shape_options: The options for configuring shape symbols.
     :param name: The name of the tensor.
     :param source: For internal use only.
     :param source_dims: For internal use only.
@@ -31,7 +31,7 @@ class Tensor:
         dtype=None,
         strides=None,
         other=None,
-        constexpr_shape=None,
+        shape_options=None,
         name=None,
         source=None,
         source_dims=None,
@@ -48,9 +48,20 @@ class Tensor:
             self.name = naming.auto_generate(f"tensor_{type(self).num_instances}")

         if ndim is not None:
+            if shape_options is None:
+                shape_options = tuple({} for _ in range(ndim))
+
+            if isinstance(shape_options, dict):
+                shape_options = tuple(shape_options for _ in range(ndim))
+
+            shape_options = tuple(
+                size_options if size_options is not None else {}
+                for size_options in shape_options
+            )
+
             self.shape = (
-                Symbol(self.size_string(i), constexpr=constexpr_shape)
-                for i in range(ndim)
+                Symbol(self.size_string(i), **size_options)
+                for i, size_options in zip(range(ndim), shape_options)
             )
             self.strides = (Symbol(self.stride_string(i)) for i in range(ndim))
         else:
@@ -364,10 +375,10 @@ class Tensor:

     def names(self):
         if self.ndim == 0:
-            return {self.source.name}
+            return {Symbol(self.source.name)}

         return (
-            {self.source.pointer_string()}
+            {Symbol(self.source.pointer_string())}
             | {
                 name
                 for value in itertools.chain(self.shape, self.strides)
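Note: shape_options generalizes the removed constexpr_shape flag. A single dict applies to every dimension, a sequence supplies per-dimension options (with None meaning defaults), and each dict is forwarded as keyword arguments to Symbol, so keys such as constexpr, lower_bound, and upper_bound are accepted. For example:

    from ninetoothed import Tensor

    # Equivalent to the old constexpr_shape=True.
    x = Tensor(2, shape_options={"constexpr": True})

    # Per-dimension control: bound dim 0, leave dim 1 at the defaults.
    y = Tensor(2, shape_options=({"constexpr": True, "upper_bound": 2**16}, None))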
ninetoothed/utils.py ADDED
@@ -0,0 +1,12 @@
+import triton
+
+
+def calculate_default_configs():
+    device = triton.runtime.driver.active.get_current_device()
+    properties = triton.runtime.driver.active.utils.get_device_properties(device)
+    max_shared_mem = properties["max_shared_mem"]
+
+    num_warps = 8
+    num_stages = max_shared_mem // 2**15
+
+    return num_warps, num_stages
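Note: this helper centralizes the device-derived defaults that make previously hard-coded as 4 and 3. Usage sketch (requires an active Triton driver):

    from ninetoothed.utils import calculate_default_configs

    # e.g. (8, 3) on a device with 96 KiB of shared memory per SM,
    # since num_stages = max_shared_mem // 2**15.
    num_warps, num_stages = calculate_default_configs()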
ninetoothed/visualization.py CHANGED
@@ -10,8 +10,6 @@ def visualize(tensor, color=None, save_path=None):
     :param color: The color to be used for visualization.
     :param save_path: The path where the visualization should be saved.
     """
-    outline_width = 0.1
-    plt.rcParams["lines.linewidth"] = 72 * outline_width

     if color is None:
         color = f"C{visualize.count}"
@@ -21,6 +19,24 @@ def visualize(tensor, color=None, save_path=None):
     width = max_pos_y + 1
     height = max_pos_x + 1

+    _, ax = _prepare_figure_and_axes(width, height)
+
+    _visualize_tensor(ax, tensor, 0, 0, color)
+
+    plt.savefig(save_path, transparent=True, bbox_inches="tight", pad_inches=0)
+
+    plt.close()
+
+    visualize.count += 1
+
+
+visualize.count = 0
+
+
+def _prepare_figure_and_axes(width, height):
+    outline_width = 0.1
+    plt.rcParams["lines.linewidth"] = 72 * outline_width
+
     fig = plt.figure(figsize=(width + outline_width, height + outline_width))

     h = (Size.Fixed(0), Size.Fixed(width + outline_width))
@@ -41,16 +57,7 @@ def visualize(tensor, color=None, save_path=None):
     plt.xlim((-half_outline_width, width + half_outline_width))
     plt.ylim((-half_outline_width, height + half_outline_width))

-    _visualize_tensor(ax, tensor, 0, 0, color)
-
-    plt.savefig(save_path, transparent=True, bbox_inches="tight", pad_inches=0)
-
-    plt.close()
-
-    visualize.count += 1
-
-
-visualize.count = 0
+    return fig, ax


 def _visualize_tensor(ax, tensor, x, y, color, level_spacing=4):
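Note: the drawing path is unchanged; figure and axes construction simply moved into the reusable _prepare_figure_and_axes helper. A usage sketch (tile shape and output path illustrative; requires matplotlib, i.e. the "all" extra):

    from ninetoothed import Tensor
    from ninetoothed.visualization import visualize

    tiled = Tensor(2).tile((2, 2))
    visualize(tiled, save_path="tiled.png")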
{ninetoothed-0.15.1.dist-info → ninetoothed-0.17.0.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: ninetoothed
-Version: 0.15.1
+Version: 0.17.0
 Summary: A domain-specific language based on Triton but providing higher-level abstraction.
 Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
 Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
@@ -10,6 +10,7 @@ Classifier: License :: OSI Approved :: Apache Software License
 Classifier: Operating System :: OS Independent
 Classifier: Programming Language :: Python :: 3
 Requires-Python: >=3.10
+Requires-Dist: sympy>=1.13.0
 Requires-Dist: triton>=3.0.0
 Provides-Extra: all
 Requires-Dist: matplotlib>=3.9.0; extra == 'all'
ninetoothed-0.17.0.dist-info/RECORD ADDED
@@ -0,0 +1,18 @@
+ninetoothed/__init__.py,sha256=F2bxRNhzcGdtADA8RehTuf-QK0xnxno8kxvr6H2L5Tg,552
+ninetoothed/aot.py,sha256=b7ykTC5roe_xg3NkZv6VyInBrEiNRwjpixCULUPRuEg,6506
+ninetoothed/cudaifier.py,sha256=5ylMr1q0B9NwbeXkpCu3o2nMGpDfh65nAQ0Az_qMQuI,877
+ninetoothed/dtype.py,sha256=-0iBleay5gYA4wtT3l17QjCesr7g26M6CSfhNJdI3k4,165
+ninetoothed/generation.py,sha256=wf8BL-x0PR6rG-9OSpgIZi8LtsIdFbqRUFiQFE5FIno,38107
+ninetoothed/jit.py,sha256=CpeSkO_zUe9DwtTJ2K2H7Bwpx-FvIHfrgzOcEosfpek,2946
+ninetoothed/language.py,sha256=ERiA4dpwiow2AT2xFeFWYg1KqlnBo6xxPGp8VZrP0Lk,574
+ninetoothed/make.py,sha256=fQKuRJL7HC2iGTAN323mlIWXz9Z3jotIoN68ur29Qlw,1834
+ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
+ninetoothed/symbol.py,sha256=lJo3NL2-T7tKbKjb6MCRLMemN94mqS3bIiG943P0Mbo,7454
+ninetoothed/tensor.py,sha256=gQEzHTcXqZVBFLc2YRfXTKxjxPWMxWN7fNl2BCfJwMs,14782
+ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
+ninetoothed/utils.py,sha256=mtRXABBVPnlgd2n1REh9oB3s_5bUsKhd3iwu3oJ5DSQ,338
+ninetoothed/visualization.py,sha256=oc3cA5qqT66_RoAs5D681SCxR5E5wgFwk95ZefdSfZU,3794
+ninetoothed-0.17.0.dist-info/METADATA,sha256=_V2M45nT4Yin-zs7hq5-yHlN6KwV5_zcA8afwXP8S-Q,7340
+ninetoothed-0.17.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ninetoothed-0.17.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ninetoothed-0.17.0.dist-info/RECORD,,
ninetoothed-0.15.1.dist-info/RECORD DELETED
@@ -1,17 +0,0 @@
-ninetoothed/__init__.py,sha256=zGaZiUzwJZ2jfwLxp7lT8ll_V5ngP5QYrfVbapftbCY,522
-ninetoothed/aot.py,sha256=5P9s-KAA7xNNdK8_fbCZEIteQlbaB_1wOl8_rEBQg9U,6128
-ninetoothed/cudaifier.py,sha256=5ylMr1q0B9NwbeXkpCu3o2nMGpDfh65nAQ0Az_qMQuI,877
-ninetoothed/dtype.py,sha256=-0iBleay5gYA4wtT3l17QjCesr7g26M6CSfhNJdI3k4,165
-ninetoothed/generation.py,sha256=Gmeh9OPmWZmF9CUY-UIIBPi-SOjFCxZjvXNwqX3uD84,30963
-ninetoothed/jit.py,sha256=0MFbFIODtw-bxuOC7WByxiVtQMeyvZkoDxvfAZ9rIFQ,2120
-ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
-ninetoothed/make.py,sha256=wRr3JwGt5E2OCquq_nzBZljdW-AJPOqH49cM08gwl4A,1287
-ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
-ninetoothed/symbol.py,sha256=UpGmx_jvaDtowADnp1DwYC3fvBXSiaMiYpU-ewkVo50,5261
-ninetoothed/tensor.py,sha256=ByTnoeqxD9lXprvy1DDp5L-zU2up52-jop9AAUrSTYk,14347
-ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
-ninetoothed/visualization.py,sha256=zlMH-0WplaboePGzcbpcj4UovpX0k2r4SysSPsNS4r4,3674
-ninetoothed-0.15.1.dist-info/METADATA,sha256=6RA1-6fYFfSTJnnwsRoVb4yIRrn4kfLhN47GNvmGji0,7311
-ninetoothed-0.15.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-ninetoothed-0.15.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-ninetoothed-0.15.1.dist-info/RECORD,,