ninetoothed 0.1.1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ninetoothed/jit.py CHANGED
@@ -5,7 +5,8 @@ import inspect
5
5
  import itertools
6
6
  import math
7
7
  import tempfile
8
- import textwrap
8
+
9
+ import triton
9
10
 
10
11
  from ninetoothed.language import attribute, call
11
12
  from ninetoothed.symbol import Symbol
@@ -33,8 +34,7 @@ class JIT:
33
34
  ):
34
35
  return type(self).handles[source_file][source_line]
35
36
 
36
- source = textwrap.dedent(inspect.getsource(self.func))
37
- tree = ast.parse(source)
37
+ tree = self._get_tree()
38
38
 
39
39
  CodeGenerator(inspect.get_annotations(self.func)).visit(tree)
40
40
  Tritonizer().visit(tree)
@@ -56,15 +56,7 @@ class JIT:
56
56
  namespace = {}
57
57
  exec(code, namespace)
58
58
 
59
- class Handle:
60
- def __init__(self, kernel, launch):
61
- self._kernel = kernel
62
- self._launch = launch
63
-
64
- def __call__(self, *args, **kwargs):
65
- return self._launch(*args, **kwargs)
66
-
67
- handle = Handle(
59
+ handle = _Handle(
68
60
  namespace[self.func.__name__],
69
61
  namespace[f"launch_{self.func.__name__}"],
70
62
  )
@@ -73,6 +65,15 @@ class JIT:
73
65
 
74
66
  return handle
75
67
 
68
+ def _get_tree(self):
69
+ module = ast.parse(inspect.getsource(inspect.getmodule(self.func)))
70
+
71
+ _AliasRestorer().visit(module)
72
+ finder = _FunctionDefFinder(self.func.__name__)
73
+ finder.visit(module)
74
+
75
+ return ast.Module(body=[finder.result], type_ignores=[])
76
+
76
77
 
77
78
  class CodeGenerator(ast.NodeTransformer):
78
79
  def __init__(self, context):
@@ -100,6 +101,29 @@ class CodeGenerator(ast.NodeTransformer):
100
101
 
101
102
  self.generic_visit(node)
102
103
 
104
+ for arg in self._args:
105
+ if not isinstance(arg, Tensor):
106
+ continue
107
+
108
+ offsets = arg.offsets()
109
+
110
+ initializations = {
111
+ type(self)._name_for_offsets(arg, dim): offs
112
+ for dim, offs in enumerate(offsets)
113
+ } | {
114
+ type(self)._name_for_pointers(arg): arg.original.pointer_string()
115
+ + sum(
116
+ type(self)._name_for_offsets(arg, dim)[
117
+ type(self)._generate_slices(arg, dim)
118
+ ]
119
+ * stride
120
+ for dim, stride in enumerate(arg.original.strides)
121
+ )
122
+ }
123
+
124
+ for target, value in reversed(initializations.items()):
125
+ node.body.insert(0, ast.Assign(targets=[target.node], value=value.node))
126
+
103
127
  return node
104
128
 
105
129
  def visit_arguments(self, node):
@@ -136,14 +160,12 @@ class CodeGenerator(ast.NodeTransformer):
136
160
  value = self._context[node.value.id]
137
161
 
138
162
  if isinstance(value, Tensor):
139
- if isinstance(node.slice, ast.Tuple):
140
- indices = value.indices() + tuple(node.slice.elts)
141
- else:
142
- indices = value.indices() + (node.slice,)
143
- offsets = value.offsets(indices)
144
- pointers = value.pointers(offsets)
145
-
146
- return call("load", pointers).node
163
+ return type(self)._generate_load(
164
+ value,
165
+ intermediate_indices=node.slice.elts
166
+ if isinstance(node.slice, ast.Tuple)
167
+ else (node.slice,),
168
+ )
147
169
 
148
170
  self.generic_visit(node)
149
171
 
@@ -166,7 +188,7 @@ class CodeGenerator(ast.NodeTransformer):
166
188
  self.generic_visit(node)
167
189
 
168
190
  if node.id in self._context and isinstance(node.ctx, ast.Load):
169
- return call("load", self._context[node.id].pointers().node).node
191
+ return type(self)._generate_load(self._context[node.id])
170
192
 
171
193
  return node
172
194
 
@@ -178,11 +200,7 @@ class CodeGenerator(ast.NodeTransformer):
178
200
  self.generic_visit(node)
179
201
 
180
202
  return ast.Expr(
181
- call(
182
- "store",
183
- self._context[target.id].pointers().node,
184
- node.value,
185
- ).node
203
+ type(self)._generate_store(self._context[target.id], node.value)
186
204
  )
187
205
  elif (
188
206
  isinstance(target, ast.Subscript)
@@ -195,20 +213,14 @@ class CodeGenerator(ast.NodeTransformer):
195
213
  if isinstance(value, Tensor):
196
214
  self.generic_visit(node)
197
215
 
198
- indices = value.indices() + tuple(
199
- target.slice.elts
200
- if isinstance(target.slice, ast.Tuple)
201
- else target.slice
202
- )
203
- offsets = value.offsets(indices)
204
- pointers = value.pointers(offsets)
205
-
206
216
  return ast.Expr(
207
- call(
208
- "store",
209
- pointers.node,
217
+ type(self)._generate_store(
218
+ value,
210
219
  node.value,
211
- ).node
220
+ intermediate_indices=target.slice.elts
221
+ if isinstance(target.slice, ast.Tuple)
222
+ else (target.slice,),
223
+ )
212
224
  )
213
225
 
214
226
  self.generic_visit(node)
@@ -216,6 +228,13 @@ class CodeGenerator(ast.NodeTransformer):
216
228
  return node
217
229
 
218
230
  def _generate_autotune(self, params, meta):
231
+ device = triton.runtime.driver.active.get_current_device()
232
+ properties = triton.runtime.driver.active.utils.get_device_properties(device)
233
+ max_shared_mem = properties["max_shared_mem"]
234
+
235
+ num_warps = 8
236
+ num_stages = max_shared_mem // 2**15
237
+
219
238
  configs = [
220
239
  ast.Call(
221
240
  func=ast.Attribute(
@@ -229,7 +248,10 @@ class CodeGenerator(ast.NodeTransformer):
229
248
  values=[ast.Constant(value=value) for value in values],
230
249
  )
231
250
  ],
232
- keywords=[],
251
+ keywords=[
252
+ ast.keyword(arg="num_warps", value=ast.Constant(value=num_warps)),
253
+ ast.keyword(arg="num_stages", value=ast.Constant(value=num_stages)),
254
+ ],
233
255
  )
234
256
  for values in itertools.product(self._POWER_OF_TWOS, repeat=len(meta))
235
257
  if self._MIN_PRODUCT <= math.prod(values) <= self._MAX_PRODUCT
@@ -256,7 +278,7 @@ class CodeGenerator(ast.NodeTransformer):
256
278
  elts=[
257
279
  ast.Constant(value=param)
258
280
  for param in params
259
- if not Tensor.is_pointer(param)
281
+ if not Tensor.pointer_pattern().fullmatch(param)
260
282
  ],
261
283
  ctx=ast.Load(),
262
284
  ),
@@ -269,7 +291,7 @@ class CodeGenerator(ast.NodeTransformer):
269
291
  name=f"launch_{self._func_def.name}",
270
292
  args=ast.arguments(
271
293
  posonlyargs=[],
272
- args=[ast.arg(arg.name) for arg in self._args],
294
+ args=[ast.arg(arg.original.name) for arg in self._args],
273
295
  kwonlyargs=[],
274
296
  defaults=[],
275
297
  ),
@@ -316,6 +338,77 @@ class CodeGenerator(ast.NodeTransformer):
316
338
 
317
339
  return ast.parse(f"lambda meta: ({num_elements},)", mode="eval").body
318
340
 
341
+ @staticmethod
342
+ def _generate_load(tensor, intermediate_indices=()):
343
+ pointers, mask = CodeGenerator._generate_pointers_and_mask(
344
+ tensor, intermediate_indices
345
+ )
346
+ other = CodeGenerator._generate_other(tensor)
347
+
348
+ return call("load", pointers, mask=mask, other=other).node
349
+
350
+ @staticmethod
351
+ def _generate_store(tensor, value, intermediate_indices=()):
352
+ pointers, mask = CodeGenerator._generate_pointers_and_mask(
353
+ tensor, intermediate_indices
354
+ )
355
+
356
+ return call("store", pointers, value, mask=mask).node
357
+
358
+ @staticmethod
359
+ def _generate_pointers_and_mask(tensor, intermediate_indices):
360
+ intermediate_offsets = CodeGenerator._generate_intermediate_offsets(
361
+ tensor, intermediate_indices
362
+ )
363
+ offsets = [
364
+ CodeGenerator._name_for_offsets(tensor, dim) + intermediate_offsets[dim]
365
+ for dim in range(tensor.original.ndim)
366
+ ]
367
+ pointers = CodeGenerator._name_for_pointers(tensor) + sum(
368
+ map(lambda x, y: x * y, intermediate_offsets, tensor.original.strides)
369
+ )
370
+ mask = functools.reduce(
371
+ lambda x, y: x & y,
372
+ (
373
+ offs[CodeGenerator._generate_slices(tensor, dim)] < size
374
+ for dim, (offs, size) in enumerate(zip(offsets, tensor.original.shape))
375
+ ),
376
+ )
377
+
378
+ return pointers, mask
379
+
380
+ @staticmethod
381
+ def _generate_other(tensor):
382
+ other = tensor.original.other
383
+
384
+ if isinstance(other, float) and not math.isfinite(other):
385
+ return f"float('{other}')"
386
+
387
+ return other
388
+
389
+ @staticmethod
390
+ def _generate_slices(tensor, dim):
391
+ return tuple(slice(None) if i == dim else None for i in range(tensor.ndim))
392
+
393
+ @staticmethod
394
+ def _generate_intermediate_offsets(tensor, intermediate_indices):
395
+ return tuple(
396
+ offs
397
+ for offs in tensor.offsets(
398
+ [0 for _ in range(tensor.ndim)]
399
+ + list(intermediate_indices)
400
+ + [0 for _ in range(tensor.inmost().ndim)]
401
+ )
402
+ )
403
+
404
+ @staticmethod
405
+ def _name_for_pointers(tensor):
406
+ return Symbol(f"{tensor.original.name}_pointers")
407
+
408
+ @staticmethod
409
+ def _name_for_offsets(tensor, dim):
410
+ return Symbol(f"{tensor.original.name}_offsets_{dim}")
411
+
319
412
 
320
413
  class Tritonizer(ast.NodeTransformer):
321
414
  def visit_Module(self, node):
@@ -329,8 +422,8 @@ class Tritonizer(ast.NodeTransformer):
329
422
  def visit_Name(self, node):
330
423
  self.generic_visit(node)
331
424
 
332
- if node.id == "ninetoothed":
333
- node.id = "triton"
425
+ if node.id == "ninetoothed" or "ninetoothed." in node.id:
426
+ node.id = node.id.replace("ninetoothed", "triton")
334
427
 
335
428
  return node
336
429
 
@@ -348,3 +441,73 @@ class Tritonizer(ast.NodeTransformer):
348
441
  )
349
442
 
350
443
  return node
444
+
445
+
446
+ class _Handle:
447
+ def __init__(self, kernel, launch):
448
+ self._kernel = kernel
449
+ self._launch = launch
450
+
451
+ def __call__(self, *args, **kwargs):
452
+ return self._launch(*args, **kwargs)
453
+
454
+
455
+ class _AliasRestorer(ast.NodeTransformer):
456
+ def __init__(self):
457
+ super().__init__()
458
+
459
+ self._aliases = {}
460
+ self._redefined = set()
461
+
462
+ def visit_Import(self, node):
463
+ for alias in node.names:
464
+ if alias.asname:
465
+ self._aliases[alias.asname] = alias.name
466
+
467
+ return node
468
+
469
+ def visit_ImportFrom(self, node):
470
+ for alias in node.names:
471
+ full_name = f"{node.module}.{alias.name}"
472
+ if alias.asname:
473
+ self._aliases[alias.asname] = full_name
474
+
475
+ return node
476
+
477
+ def visit_Assign(self, node):
478
+ for target in node.targets:
479
+ if isinstance(target, ast.Name):
480
+ self._redefined.add(target.id)
481
+
482
+ return self.generic_visit(node)
483
+
484
+ def visit_FunctionDef(self, node):
485
+ original_redefined = self._redefined.copy()
486
+
487
+ self.generic_visit(node)
488
+
489
+ self._redefined = original_redefined
490
+
491
+ return node
492
+
493
+ def visit_Name(self, node):
494
+ if node.id in self._redefined:
495
+ return node
496
+
497
+ if node.id in self._aliases:
498
+ return ast.Name(id=self._aliases[node.id], ctx=node.ctx)
499
+
500
+ return node
501
+
502
+
503
+ class _FunctionDefFinder(ast.NodeVisitor):
504
+ def __init__(self, name):
505
+ self._name = name
506
+
507
+ self.result = None
508
+
509
+ def visit_FunctionDef(self, node):
510
+ if node.name == self._name:
511
+ self.result = node
512
+
513
+ self.generic_visit(node)
ninetoothed/language.py CHANGED
@@ -10,7 +10,10 @@ def call(func, *args, **kwargs):
10
10
  ast.Call(
11
11
  func=attribute(func).node,
12
12
  args=[Symbol(arg).node for arg in args],
13
- keywords=[(kwarg, Symbol(kwargs[kwarg]).node) for kwarg in kwargs],
13
+ keywords=[
14
+ ast.keyword(arg=kwarg, value=Symbol(kwargs[kwarg]).node)
15
+ for kwarg in kwargs
16
+ ],
14
17
  )
15
18
  )
16
19
 
ninetoothed/symbol.py CHANGED
@@ -34,37 +34,80 @@ class Symbol:
34
34
  self._node.id = type(self)._create_constexpr(self._node.id)
35
35
 
36
36
  def __add__(self, other):
37
- return type(self)(
38
- ast.BinOp(left=self._node, op=ast.Add(), right=type(self)(other)._node)
39
- )
37
+ other = type(self)(other)
38
+
39
+ if isinstance(self._node, ast.Constant) and self._node.value == 0:
40
+ return other
41
+
42
+ if isinstance(other._node, ast.Constant) and other._node.value == 0:
43
+ return self
44
+
45
+ return type(self)(ast.BinOp(left=self._node, op=ast.Add(), right=other._node))
40
46
 
41
47
  def __radd__(self, other):
42
48
  return self.__add__(other)
43
49
 
44
50
  def __mul__(self, other):
45
- return type(self)(
46
- ast.BinOp(left=self._node, op=ast.Mult(), right=type(self)(other)._node)
47
- )
51
+ other = type(self)(other)
52
+
53
+ if isinstance(self._node, ast.Constant) and self._node.value == 0:
54
+ return type(self)(0)
55
+
56
+ if isinstance(other._node, ast.Constant) and other._node.value == 0:
57
+ return type(self)(0)
58
+
59
+ if isinstance(self._node, ast.Constant) and self._node.value == 1:
60
+ return other
61
+
62
+ if isinstance(other._node, ast.Constant) and other._node.value == 1:
63
+ return self
64
+
65
+ return type(self)(ast.BinOp(left=self._node, op=ast.Mult(), right=other._node))
48
66
 
49
67
  def __rmul__(self, other):
50
68
  return self.__mul__(other)
51
69
 
52
70
  def __floordiv__(self, other):
71
+ other = type(self)(other)
72
+
73
+ if isinstance(other._node, ast.Constant) and other._node.value == 1:
74
+ return self
75
+
53
76
  return type(self)(
54
- ast.BinOp(left=self._node, op=ast.FloorDiv(), right=type(self)(other)._node)
77
+ ast.BinOp(left=self._node, op=ast.FloorDiv(), right=other._node)
55
78
  )
56
79
 
57
80
  def __mod__(self, other):
81
+ other = type(self)(other)
82
+
83
+ return type(self)(ast.BinOp(left=self._node, op=ast.Mod(), right=other._node))
84
+
85
+ def __lt__(self, other):
86
+ other = type(self)(other)
87
+
88
+ return type(self)(
89
+ ast.Compare(left=self._node, ops=[ast.Lt()], comparators=[other._node])
90
+ )
91
+
92
+ def __and__(self, other):
93
+ other = type(self)(other)
94
+
58
95
  return type(self)(
59
- ast.BinOp(left=self._node, op=ast.Mod(), right=type(self)(other)._node)
96
+ ast.BinOp(left=self._node, op=ast.BitAnd(), right=other._node)
60
97
  )
61
98
 
99
+ def __rand__(self, other):
100
+ return self.__and__(other)
101
+
62
102
  def __getitem__(self, key):
63
103
  return type(self)(ast.Subscript(value=self._node, slice=type(self)(key)._node))
64
104
 
65
105
  def __repr__(self):
66
106
  return ast.unparse(self._node)
67
107
 
108
+ def find_and_replace(self, target, replacement):
109
+ _FindAndReplacer(target.node, replacement.node).visit(self._node)
110
+
68
111
  def names(self):
69
112
  class NameCollector(ast.NodeVisitor):
70
113
  def __init__(self):
@@ -107,3 +150,15 @@ class Symbol:
107
150
  @staticmethod
108
151
  def _create_meta(name):
109
152
  return f"_ninetoothed_meta_{name}"
153
+
154
+
155
+ class _FindAndReplacer(ast.NodeTransformer):
156
+ def __init__(self, target, replacement):
157
+ self._target_id = target.id
158
+ self._replacement = replacement
159
+
160
+ def visit_Name(self, node):
161
+ if node.id == self._target_id:
162
+ return self._replacement
163
+
164
+ return self.generic_visit(node)
ninetoothed/tensor.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import itertools
2
+ import re
2
3
 
3
4
  from ninetoothed.language import call
4
5
  from ninetoothed.symbol import Symbol
@@ -7,19 +8,24 @@ from ninetoothed.symbol import Symbol
7
8
  class Tensor:
8
9
  num_instances = 0
9
10
 
10
- def __init__(self, ndim=None, shape=None, dtype=None, strides=None, name=None):
11
+ def __init__(
12
+ self,
13
+ ndim=None,
14
+ shape=None,
15
+ dtype=None,
16
+ strides=None,
17
+ other=None,
18
+ original=None,
19
+ ):
11
20
  type(self).num_instances += 1
12
21
 
13
22
  self.dtype = dtype
14
23
 
15
- if name is not None:
16
- self.name = name
17
- else:
18
- self.name = f"tensor_{type(self).num_instances}"
24
+ self.name = f"tensor_{type(self).num_instances}"
19
25
 
20
26
  if ndim is not None:
21
- self.shape = [Symbol(f"{self.name}_size_{i}") for i in range(ndim)]
22
- self.strides = [Symbol(f"{self.name}_stride_{i}") for i in range(ndim)]
27
+ self.shape = [Symbol(self.size_string(i)) for i in range(ndim)]
28
+ self.strides = [Symbol(self.stride_string(i)) for i in range(ndim)]
23
29
  else:
24
30
  self.shape = shape
25
31
 
@@ -28,6 +34,13 @@ class Tensor:
28
34
  else:
29
35
  self.strides = self._calculate_default_strides(shape)
30
36
 
37
+ self.other = other
38
+
39
+ if original is not None:
40
+ self.original = original
41
+ else:
42
+ self.original = self
43
+
31
44
  def tile(self, tile_shape, tile_strides=None):
32
45
  if tile_strides is None:
33
46
  tile_strides = [1 for _ in tile_shape]
@@ -59,10 +72,10 @@ class Tensor:
59
72
  shape=inner_shape,
60
73
  dtype=self.dtype,
61
74
  strides=inner_strides,
62
- name=self.name,
75
+ original=self.original,
63
76
  ),
64
77
  strides=outer_strides,
65
- name=self.name,
78
+ original=self.original,
66
79
  )
67
80
 
68
81
  def expand(self, shape):
@@ -77,12 +90,21 @@ class Tensor:
77
90
  stride if new_size == -1 else 0
78
91
  for new_size, stride in zip(shape, self.strides)
79
92
  ],
80
- name=self.name,
93
+ original=self.original,
94
+ )
95
+
96
+ def squeeze(self, dim):
97
+ # TODO: Add error handling.
98
+ return type(self)(
99
+ shape=[size for i, size in enumerate(self.shape) if dim != i],
100
+ dtype=self.dtype,
101
+ strides=[stride for i, stride in enumerate(self.strides) if dim != i],
102
+ original=self.original,
81
103
  )
82
104
 
83
105
  def names(self):
84
106
  return (
85
- {self._pointer()}
107
+ {self.original.pointer_string()}
86
108
  | {
87
109
  name
88
110
  for value in itertools.chain(self.shape, self.strides)
@@ -92,34 +114,31 @@ class Tensor:
92
114
  | (self.dtype.names() if isinstance(self.dtype, type(self)) else set())
93
115
  )
94
116
 
95
- def pointers(self, offsets=None):
96
- if offsets is None:
97
- offsets = self.offsets()
98
-
99
- return self._pointer() + offsets
100
-
101
117
  def offsets(self, indices=None):
102
118
  if indices is None:
103
119
  indices = self.indices()
104
120
 
105
- if not isinstance(self.dtype, type(self)):
106
- if indices:
107
- raise IndexError("Incorrect number of indices.")
121
+ offsets = [[] for _ in range(self.original.ndim)]
122
+
123
+ curr = self
124
+ start = 0
125
+
126
+ while isinstance(curr, type(self)):
127
+ stop = start + curr.ndim
128
+ curr_indices = indices[start:stop]
129
+
130
+ for index, stride in zip(curr_indices, curr.strides):
131
+ for dim in self._dims_of(stride):
132
+ offsets[dim].append(index * stride)
108
133
 
109
- return sum(
110
- self.stride(idx)
111
- * call("arange", 0, self.size(idx))[
112
- tuple(slice(None) if i == idx else None for i in range(self.ndim()))
113
- ]
114
- for idx in range(self.ndim())
115
- )
134
+ start = stop
135
+ curr = curr.dtype
116
136
 
117
- outer_indices = indices[: self.ndim()]
118
- inner_indices = indices[self.ndim() :]
137
+ for dim in range(self.original.ndim):
138
+ offsets[dim] = sum(offsets[dim])
139
+ offsets[dim].find_and_replace(Symbol(self.original.strides[dim]), Symbol(1))
119
140
 
120
- return sum(
121
- index * stride for index, stride in zip(outer_indices, self.strides)
122
- ) + self.dtype.offsets(inner_indices)
141
+ return offsets
123
142
 
124
143
  def indices(self, index=None):
125
144
  if index is None:
@@ -127,14 +146,38 @@ class Tensor:
127
146
 
128
147
  indices = []
129
148
 
130
- for stride in type(self)(shape=self.shape, name=self.name).strides:
149
+ for stride in type(self)(shape=self.shape, original=self.original).strides:
131
150
  indices.append(index // stride)
132
151
  index %= stride
133
152
 
153
+ curr = self.dtype
154
+
155
+ while isinstance(curr.dtype, type(self)):
156
+ for _ in range(curr.ndim):
157
+ indices.append(0)
158
+
159
+ curr = curr.dtype
160
+
161
+ if isinstance(curr, type(self)):
162
+ for dim in range(curr.ndim):
163
+ indices.append(call("arange", 0, curr.shape[dim]))
164
+
134
165
  return tuple(indices)
135
166
 
136
- def ndim(self):
137
- return len(self.shape)
167
+ def inmost(self):
168
+ if not isinstance(self.dtype, type(self)):
169
+ return self
170
+
171
+ return self.dtype.inmost()
172
+
173
+ def pointer_string(self):
174
+ return f"{self.name}_pointer"
175
+
176
+ def size_string(self, dim):
177
+ return f"{self.name}_size_{dim}"
178
+
179
+ def stride_string(self, dim):
180
+ return f"{self.name}_stride_{dim}"
138
181
 
139
182
  def size(self, dim=None):
140
183
  if dim is None:
@@ -148,12 +191,31 @@ class Tensor:
148
191
 
149
192
  return self.strides[dim]
150
193
 
194
+ @property
195
+ def ndim(self):
196
+ return len(self.shape)
197
+
151
198
  @staticmethod
152
- def is_pointer(name):
153
- return name.endswith("_ptr")
199
+ def pointer_pattern():
200
+ return re.compile(rf"({_identifier_pattern_raw_string()})_(pointer)")
154
201
 
155
- def _pointer(self):
156
- return f"{self.name}_ptr"
202
+ @staticmethod
203
+ def size_pattern():
204
+ return re.compile(rf"({_identifier_pattern_raw_string()})_(size)_(.+)")
205
+
206
+ @staticmethod
207
+ def stride_pattern():
208
+ return re.compile(rf"({_identifier_pattern_raw_string()})_(stride)_(.+)")
209
+
210
+ def _dims_of(self, stride):
211
+ dims = set()
212
+ names = stride.names() if isinstance(stride, Symbol) else {stride}
213
+
214
+ for dim, original_stride in enumerate(self.original.strides):
215
+ if str(original_stride) in names:
216
+ dims.add(dim)
217
+
218
+ return dims
157
219
 
158
220
  @staticmethod
159
221
  def _calculate_default_strides(shape):
@@ -163,3 +225,7 @@ class Tensor:
163
225
  strides.append(size * strides[-1])
164
226
 
165
227
  return reversed(strides)
228
+
229
+
230
+ def _identifier_pattern_raw_string():
231
+ return r"[a-zA-Z_][a-zA-Z0-9_]*"
ninetoothed/torchifier.py CHANGED
@@ -1,23 +1,27 @@
1
1
  import ast
2
- import re
2
+
3
+ from ninetoothed.tensor import Tensor
3
4
 
4
5
 
5
6
  class Torchifier(ast.NodeTransformer):
6
7
  def visit_Name(self, node):
7
8
  self.generic_visit(node)
8
9
 
9
- pattern = re.compile(r"([a-zA-Z_][a-zA-Z0-9_]*)_(size|stride)_(.+)")
10
+ source = node.id
11
+
12
+ def repl(match):
13
+ return f"{match.group(1)}"
14
+
15
+ source = Tensor.pointer_pattern().sub(repl, source)
16
+
17
+ def repl(match):
18
+ return f"{match.group(1)}.{match.group(2)}({match.group(3)})"
10
19
 
11
- node.id = node.id.replace("_ptr", "")
20
+ source = Tensor.size_pattern().sub(repl, source)
21
+ source = Tensor.stride_pattern().sub(repl, source)
12
22
 
13
- if re.fullmatch(pattern, node.id):
14
- return ast.parse(
15
- pattern.sub(
16
- lambda match: f"{match.group(1)}.{match.group(2)}({match.group(3)})",
17
- node.id,
18
- ),
19
- mode="eval",
20
- ).body
23
+ if source != node.id:
24
+ return ast.parse(source, mode="eval").body
21
25
 
22
26
  return node
23
27
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: ninetoothed
3
- Version: 0.1.1
3
+ Version: 0.3.0
4
4
  Summary: A domain-specific language based on Triton but providing higher-level abstraction.
5
5
  Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
6
6
  Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
@@ -10,6 +10,7 @@ Classifier: License :: OSI Approved :: Apache Software License
10
10
  Classifier: Operating System :: OS Independent
11
11
  Classifier: Programming Language :: Python :: 3
12
12
  Requires-Python: >=3.10
13
+ Requires-Dist: triton>=3.0.0
13
14
  Description-Content-Type: text/markdown
14
15
 
15
16
  # NineToothed
@@ -26,7 +27,7 @@ We can use `pip` to install `ninetoothed`.
26
27
  pip install ninetoothed
27
28
  ```
28
29
 
29
- After successfully running the above command, `ninetoothed` will be installed. However, to fully utilize its capabilities, you also need to install `triton` and a deep learning framework supported by `ninetoothed`. For trial purposes, we recommend installing `triton` and `torch`.
30
+ After successfully running the above command, `ninetoothed` will be installed. However, to fully utilize its capabilities, you also need to install a deep learning framework supported by `ninetoothed`. For trial purposes, we recommend installing `torch`.
30
31
 
31
32
  ## Usage
32
33
 
@@ -64,16 +65,19 @@ c_tiled = Tensor(2).tile((BLOCK_SIZE_M, BLOCK_SIZE_N))
64
65
  a_tiled = a_tiled.expand((-1, c_tiled.shape[1]))
65
66
  b_tiled = b_tiled.expand((c_tiled.shape[0], -1))
66
67
 
68
+ a_tiled.dtype = a_tiled.dtype.squeeze(0)
69
+ b_tiled.dtype = b_tiled.dtype.squeeze(1)
70
+
67
71
  @ninetoothed.jit
68
72
  def matmul_kernel(a: a_tiled, b: b_tiled, c: c_tiled):
69
73
  accumulator = ninetoothed.language.zeros(
70
74
  c.shape, dtype=ninetoothed.language.float32
71
75
  )
72
- for k in range(a.shape[1]):
73
- accumulator = ninetoothed.language.dot(a[0, k], b[k, 0], accumulator)
76
+ for k in range(a.shape[0]):
77
+ accumulator += ninetoothed.language.dot(a[k], b[k])
74
78
  c = accumulator.to(ninetoothed.language.float16)
75
79
  ```
76
80
 
77
- For matrix multiplication, we also have three tensor parameters, but the tiling method is more complex than vector addition. We denote the three matrices as $A$, $B$, and $C$, where $A$ and $B$ are inputs, and $C$ is the output. Tiling $C$ is simple; we just need to divide it into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_N)` by rows and columns. Once each block computes its result, the entire $C$ is computed. However, how should we tile $A$ and $B$? The answer is to introduce another meta-parameter `BLOCK_SIZE_K`. This way, we can divide $A$ into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_K)` and $B$ into blocks of size `(BLOCK_SIZE_K, BLOCK_SIZE_N)`. However, for matrix multiplication, $A$ and $B$ do not correspond block by block; each row of $A$ needs to correspond to each column of $B$. Therefore, we need to further `tile` $A$ and $B$ by rows and columns, respectively. Up to this point, we have a set of row blocks of $A$ and column blocks of $B$. However, each row block of $A$ must correspond to every column block of $B$. This is where `expand` comes in. We `expand` the row blocks of $A$ along the columns to the number of columns of $C$ and the column blocks of $B$ along the rows to the number of rows of $C$. This way, we successfully tile $A$, $B$, and $C$.
81
+ For matrix multiplication, we also have three tensor parameters, but the tiling method is more complex than vector addition. We denote the three matrices as $A$, $B$, and $C$, where $A$ and $B$ are inputs, and $C$ is the output. Tiling $C$ is simple; we just need to divide it into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_N)` by rows and columns. Once each block computes its result, the entire $C$ is computed. However, how should we tile $A$ and $B$? The answer is to introduce another meta-parameter `BLOCK_SIZE_K`. This way, we can divide $A$ into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_K)` and $B$ into blocks of size `(BLOCK_SIZE_K, BLOCK_SIZE_N)`. However, for matrix multiplication, $A$ and $B$ do not correspond block by block; each row of $A$ needs to correspond to each column of $B$. Therefore, we need to further `tile` $A$ and $B$ by rows and columns, respectively. Up to this point, we have a set of row blocks of $A$ and column blocks of $B$. However, each row block of $A$ must correspond to every column block of $B$. This is where `expand` comes in. We `expand` the row blocks of $A$ along the columns to the number of columns of $C$ and the column blocks of $B$ along the rows to the number of rows of $C$. This way, we successfully tile $A$, $B$, and $C$. In fact, our meta-operations up to this point have already enabled us to write kernel functions. However, we notice that the levels where the row blocks and column blocks reside, which we mentioned earlier, are two-dimensional, and their sizes are of the forms `(1, ...)` and `(..., 1)`. This means that if no other operations are performed, the way we access row blocks and column blocks would have to be `a[0, k]` and `b[k, 0]`. If we want to use `a` to find the range of `k`, we would need to use `a.shape[1]`, but we know that dimensions of size `1` can actually be removed completely. This is why we added two lines of `squeeze`. 
The `dtype` refers to the data type, which in PyTorch can generally be some integer or floating-point type, such as `torch.float32`. However, since meta-operations like `tile` can be performed in NineToothed, `dtype` can also be a `Tensor`. In other words, there is a concept of "tensors that store tensors" in NineToothed. In summary, these two lines perform operations on the tensors stored in the outermost tensor, removing the dimensions of size `1`. This way, when we access the row and column blocks, we can use `a[k]` and `b[k]`, and when finding the range of `k`, we can use `a.shape[0]`.
78
82
 
79
83
  With tiling done, the rest is simple. In the function body, we define an `accumulator` to accumulate intermediate results. We then iterate through the corresponding row blocks of $A$ and column blocks of $B$, multiplying them and accumulating the results in `accumulator`. Finally, we place the `accumulator` in the corresponding block of $C$. Since each block of the parameter tensors undergoes this operation, the multiplication is completed for the whole tensors as well.
@@ -0,0 +1,10 @@
1
+ ninetoothed/__init__.py,sha256=T5UJXlC-wbo8JKPbLUNT65Kccp12xP52WFV5FsugETI,147
2
+ ninetoothed/jit.py,sha256=nhjZRi8_kcjWZX0eOrnxLlzJfVg5vn12f9oi0Er2ABE,15515
3
+ ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
4
+ ninetoothed/symbol.py,sha256=Bd54qcI8KQAX0JRE_wPXycswtdSofhZ6Rr5MtZcv9fo,4665
5
+ ninetoothed/tensor.py,sha256=_DrjOJ-pBvEbSNUvUoYJduLQXmuKgNcqhe4xUDMVoZw,6275
6
+ ninetoothed/torchifier.py,sha256=8M2PDwyFIfVypX6Z-Vt_bGbsCPqxqKnftL0rXeh9bOM,911
7
+ ninetoothed-0.3.0.dist-info/METADATA,sha256=CqdtfdV0eHzSwxJmFpD2IG5d4WTc6RDlpqMZue4Ml2Q,6720
8
+ ninetoothed-0.3.0.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
9
+ ninetoothed-0.3.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
10
+ ninetoothed-0.3.0.dist-info/RECORD,,
@@ -1,10 +0,0 @@
1
- ninetoothed/__init__.py,sha256=T5UJXlC-wbo8JKPbLUNT65Kccp12xP52WFV5FsugETI,147
2
- ninetoothed/jit.py,sha256=DdRdZ7DhfZwJeS7AcO_RhD9TZcCebKI55V4_6UHs3bo,10523
3
- ninetoothed/language.py,sha256=cSuTgi5OwmLFy-dy_AHGZzRm18wz01ByHQ2vioP1vTg,437
4
- ninetoothed/symbol.py,sha256=8BI4ekeLuUdHTEREvMMlAzwrJ93pqiCdSHGc38clBFA,3034
5
- ninetoothed/tensor.py,sha256=o_HLEuaBzojmbMLnbPGLcw4iqBI34TNdES3YLTagztE,4590
6
- ninetoothed/torchifier.py,sha256=JmIVQE8r0zr_RLExsRDOGNsMu0F7v6J_o22aWqlw81k,841
7
- ninetoothed-0.1.1.dist-info/METADATA,sha256=1Nv6Xcz7CrpEUrzAYH93bYVX8GfPtHwzj4yofeaoJro,5422
8
- ninetoothed-0.1.1.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
9
- ninetoothed-0.1.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
10
- ninetoothed-0.1.1.dist-info/RECORD,,