ninetoothed 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ninetoothed/jit.py CHANGED
@@ -6,6 +6,8 @@ import itertools
6
6
  import math
7
7
  import tempfile
8
8
 
9
+ import triton
10
+
9
11
  from ninetoothed.language import attribute, call
10
12
  from ninetoothed.symbol import Symbol
11
13
  from ninetoothed.tensor import Tensor
@@ -103,13 +105,24 @@ class CodeGenerator(ast.NodeTransformer):
103
105
  if not isinstance(arg, Tensor):
104
106
  continue
105
107
 
106
- node.body.insert(
107
- 0,
108
- ast.Assign(
109
- targets=[Symbol(f"{arg.name}_ptrs").node],
110
- value=arg.pointers().node,
111
- ),
112
- )
108
+ offsets = arg.offsets()
109
+
110
+ initializations = {
111
+ type(self)._name_for_offsets(arg, dim): offs
112
+ for dim, offs in enumerate(offsets)
113
+ } | {
114
+ type(self)._name_for_pointers(arg): arg.original.pointer_string()
115
+ + sum(
116
+ type(self)._name_for_offsets(arg, dim)[
117
+ type(self)._generate_slices(arg, dim)
118
+ ]
119
+ * stride
120
+ for dim, stride in enumerate(arg.original.strides)
121
+ )
122
+ }
123
+
124
+ for target, value in reversed(initializations.items()):
125
+ node.body.insert(0, ast.Assign(targets=[target.node], value=value.node))
113
126
 
114
127
  return node
115
128
 
@@ -147,15 +160,13 @@ class CodeGenerator(ast.NodeTransformer):
147
160
  value = self._context[node.value.id]
148
161
 
149
162
  if isinstance(value, Tensor):
150
- pointers = type(self)._create_pointers(
163
+ return type(self)._generate_load(
151
164
  value,
152
- node.slice.elts
165
+ intermediate_indices=node.slice.elts
153
166
  if isinstance(node.slice, ast.Tuple)
154
167
  else (node.slice,),
155
168
  )
156
169
 
157
- return call("load", pointers).node
158
-
159
170
  self.generic_visit(node)
160
171
 
161
172
  return node
@@ -177,9 +188,7 @@ class CodeGenerator(ast.NodeTransformer):
177
188
  self.generic_visit(node)
178
189
 
179
190
  if node.id in self._context and isinstance(node.ctx, ast.Load):
180
- return call(
181
- "load", type(self)._create_pointers(self._context[node.id], ()).node
182
- ).node
191
+ return type(self)._generate_load(self._context[node.id])
183
192
 
184
193
  return node
185
194
 
@@ -191,11 +200,7 @@ class CodeGenerator(ast.NodeTransformer):
191
200
  self.generic_visit(node)
192
201
 
193
202
  return ast.Expr(
194
- call(
195
- "store",
196
- type(self)._create_pointers(self._context[target.id], ()).node,
197
- node.value,
198
- ).node
203
+ type(self)._generate_store(self._context[target.id], node.value)
199
204
  )
200
205
  elif (
201
206
  isinstance(target, ast.Subscript)
@@ -208,19 +213,14 @@ class CodeGenerator(ast.NodeTransformer):
208
213
  if isinstance(value, Tensor):
209
214
  self.generic_visit(node)
210
215
 
211
- pointers = type(self)._create_pointers(
212
- value,
213
- target.slice.elts
214
- if isinstance(target.slice, ast.Tuple)
215
- else (target.slice,),
216
- )
217
-
218
216
  return ast.Expr(
219
- call(
220
- "store",
221
- pointers.node,
217
+ type(self)._generate_store(
218
+ value,
222
219
  node.value,
223
- ).node
220
+ intermediate_indices=target.slice.elts
221
+ if isinstance(target.slice, ast.Tuple)
222
+ else (target.slice,),
223
+ )
224
224
  )
225
225
 
226
226
  self.generic_visit(node)
@@ -228,6 +228,13 @@ class CodeGenerator(ast.NodeTransformer):
228
228
  return node
229
229
 
230
230
  def _generate_autotune(self, params, meta):
231
+ device = triton.runtime.driver.active.get_current_device()
232
+ properties = triton.runtime.driver.active.utils.get_device_properties(device)
233
+ max_shared_mem = properties["max_shared_mem"]
234
+
235
+ num_warps = 8
236
+ num_stages = max_shared_mem // 2**15
237
+
231
238
  configs = [
232
239
  ast.Call(
233
240
  func=ast.Attribute(
@@ -241,7 +248,10 @@ class CodeGenerator(ast.NodeTransformer):
241
248
  values=[ast.Constant(value=value) for value in values],
242
249
  )
243
250
  ],
244
- keywords=[],
251
+ keywords=[
252
+ ast.keyword(arg="num_warps", value=ast.Constant(value=num_warps)),
253
+ ast.keyword(arg="num_stages", value=ast.Constant(value=num_stages)),
254
+ ],
245
255
  )
246
256
  for values in itertools.product(self._POWER_OF_TWOS, repeat=len(meta))
247
257
  if self._MIN_PRODUCT <= math.prod(values) <= self._MAX_PRODUCT
@@ -268,7 +278,7 @@ class CodeGenerator(ast.NodeTransformer):
268
278
  elts=[
269
279
  ast.Constant(value=param)
270
280
  for param in params
271
- if not Tensor.is_pointer(param)
281
+ if not Tensor.pointer_pattern().fullmatch(param)
272
282
  ],
273
283
  ctx=ast.Load(),
274
284
  ),
@@ -281,7 +291,7 @@ class CodeGenerator(ast.NodeTransformer):
281
291
  name=f"launch_{self._func_def.name}",
282
292
  args=ast.arguments(
283
293
  posonlyargs=[],
284
- args=[ast.arg(arg.name) for arg in self._args],
294
+ args=[ast.arg(arg.original.name) for arg in self._args],
285
295
  kwonlyargs=[],
286
296
  defaults=[],
287
297
  ),
@@ -329,12 +339,75 @@ class CodeGenerator(ast.NodeTransformer):
329
339
  return ast.parse(f"lambda meta: ({num_elements},)", mode="eval").body
330
340
 
331
341
  @staticmethod
332
- def _create_pointers(tensor, indices):
333
- return Symbol(f"{tensor.name}_ptrs") + tensor.offsets(
334
- [0 for _ in range(tensor.ndim())]
335
- + list(indices)
336
- + [0 for _ in range(tensor.inmost().ndim())]
342
+ def _generate_load(tensor, intermediate_indices=()):
343
+ pointers, mask = CodeGenerator._generate_pointers_and_mask(
344
+ tensor, intermediate_indices
345
+ )
346
+ other = CodeGenerator._generate_other(tensor)
347
+
348
+ return call("load", pointers, mask=mask, other=other).node
349
+
350
+ @staticmethod
351
+ def _generate_store(tensor, value, intermediate_indices=()):
352
+ pointers, mask = CodeGenerator._generate_pointers_and_mask(
353
+ tensor, intermediate_indices
354
+ )
355
+
356
+ return call("store", pointers, value, mask=mask).node
357
+
358
+ @staticmethod
359
+ def _generate_pointers_and_mask(tensor, intermediate_indices):
360
+ intermediate_offsets = CodeGenerator._generate_intermediate_offsets(
361
+ tensor, intermediate_indices
337
362
  )
363
+ offsets = [
364
+ CodeGenerator._name_for_offsets(tensor, dim) + intermediate_offsets[dim]
365
+ for dim in range(tensor.original.ndim)
366
+ ]
367
+ pointers = CodeGenerator._name_for_pointers(tensor) + sum(
368
+ map(lambda x, y: x * y, intermediate_offsets, tensor.original.strides)
369
+ )
370
+ mask = functools.reduce(
371
+ lambda x, y: x & y,
372
+ (
373
+ offs[CodeGenerator._generate_slices(tensor, dim)] < size
374
+ for dim, (offs, size) in enumerate(zip(offsets, tensor.original.shape))
375
+ ),
376
+ )
377
+
378
+ return pointers, mask
379
+
380
+ @staticmethod
381
+ def _generate_other(tensor):
382
+ other = tensor.original.other
383
+
384
+ if isinstance(other, float) and not math.isfinite(other):
385
+ return f"float('{other}')"
386
+
387
+ return other
388
+
389
+ @staticmethod
390
+ def _generate_slices(tensor, dim):
391
+ return tuple(slice(None) if i == dim else None for i in range(tensor.ndim))
392
+
393
+ @staticmethod
394
+ def _generate_intermediate_offsets(tensor, intermediate_indices):
395
+ return tuple(
396
+ offs
397
+ for offs in tensor.offsets(
398
+ [0 for _ in range(tensor.ndim)]
399
+ + list(intermediate_indices)
400
+ + [0 for _ in range(tensor.inmost().ndim)]
401
+ )
402
+ )
403
+
404
+ @staticmethod
405
+ def _name_for_pointers(tensor):
406
+ return Symbol(f"{tensor.original.name}_pointers")
407
+
408
+ @staticmethod
409
+ def _name_for_offsets(tensor, dim):
410
+ return Symbol(f"{tensor.original.name}_offsets_{dim}")
338
411
 
339
412
 
340
413
  class Tritonizer(ast.NodeTransformer):
ninetoothed/language.py CHANGED
@@ -10,7 +10,10 @@ def call(func, *args, **kwargs):
10
10
  ast.Call(
11
11
  func=attribute(func).node,
12
12
  args=[Symbol(arg).node for arg in args],
13
- keywords=[(kwarg, Symbol(kwargs[kwarg]).node) for kwarg in kwargs],
13
+ keywords=[
14
+ ast.keyword(arg=kwarg, value=Symbol(kwargs[kwarg]).node)
15
+ for kwarg in kwargs
16
+ ],
14
17
  )
15
18
  )
16
19
 
ninetoothed/symbol.py CHANGED
@@ -78,16 +78,36 @@ class Symbol:
78
78
  )
79
79
 
80
80
  def __mod__(self, other):
81
+ other = type(self)(other)
82
+
83
+ return type(self)(ast.BinOp(left=self._node, op=ast.Mod(), right=other._node))
84
+
85
+ def __lt__(self, other):
86
+ other = type(self)(other)
87
+
88
+ return type(self)(
89
+ ast.Compare(left=self._node, ops=[ast.Lt()], comparators=[other._node])
90
+ )
91
+
92
+ def __and__(self, other):
93
+ other = type(self)(other)
94
+
81
95
  return type(self)(
82
- ast.BinOp(left=self._node, op=ast.Mod(), right=type(self)(other)._node)
96
+ ast.BinOp(left=self._node, op=ast.BitAnd(), right=other._node)
83
97
  )
84
98
 
99
+ def __rand__(self, other):
100
+ return self.__and__(other)
101
+
85
102
  def __getitem__(self, key):
86
103
  return type(self)(ast.Subscript(value=self._node, slice=type(self)(key)._node))
87
104
 
88
105
  def __repr__(self):
89
106
  return ast.unparse(self._node)
90
107
 
108
+ def find_and_replace(self, target, replacement):
109
+ _FindAndReplacer(target.node, replacement.node).visit(self._node)
110
+
91
111
  def names(self):
92
112
  class NameCollector(ast.NodeVisitor):
93
113
  def __init__(self):
@@ -130,3 +150,15 @@ class Symbol:
130
150
  @staticmethod
131
151
  def _create_meta(name):
132
152
  return f"_ninetoothed_meta_{name}"
153
+
154
+
155
+ class _FindAndReplacer(ast.NodeTransformer):
156
+ def __init__(self, target, replacement):
157
+ self._target_id = target.id
158
+ self._replacement = replacement
159
+
160
+ def visit_Name(self, node):
161
+ if node.id == self._target_id:
162
+ return self._replacement
163
+
164
+ return self.generic_visit(node)
ninetoothed/tensor.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import itertools
2
+ import re
2
3
 
3
4
  from ninetoothed.language import call
4
5
  from ninetoothed.symbol import Symbol
@@ -7,19 +8,24 @@ from ninetoothed.symbol import Symbol
7
8
  class Tensor:
8
9
  num_instances = 0
9
10
 
10
- def __init__(self, ndim=None, shape=None, dtype=None, strides=None, name=None):
11
+ def __init__(
12
+ self,
13
+ ndim=None,
14
+ shape=None,
15
+ dtype=None,
16
+ strides=None,
17
+ other=None,
18
+ original=None,
19
+ ):
11
20
  type(self).num_instances += 1
12
21
 
13
22
  self.dtype = dtype
14
23
 
15
- if name is not None:
16
- self.name = name
17
- else:
18
- self.name = f"tensor_{type(self).num_instances}"
24
+ self.name = f"tensor_{type(self).num_instances}"
19
25
 
20
26
  if ndim is not None:
21
- self.shape = [Symbol(f"{self.name}_size_{i}") for i in range(ndim)]
22
- self.strides = [Symbol(f"{self.name}_stride_{i}") for i in range(ndim)]
27
+ self.shape = [Symbol(self.size_string(i)) for i in range(ndim)]
28
+ self.strides = [Symbol(self.stride_string(i)) for i in range(ndim)]
23
29
  else:
24
30
  self.shape = shape
25
31
 
@@ -28,6 +34,13 @@ class Tensor:
28
34
  else:
29
35
  self.strides = self._calculate_default_strides(shape)
30
36
 
37
+ self.other = other
38
+
39
+ if original is not None:
40
+ self.original = original
41
+ else:
42
+ self.original = self
43
+
31
44
  def tile(self, tile_shape, tile_strides=None):
32
45
  if tile_strides is None:
33
46
  tile_strides = [1 for _ in tile_shape]
@@ -59,10 +72,10 @@ class Tensor:
59
72
  shape=inner_shape,
60
73
  dtype=self.dtype,
61
74
  strides=inner_strides,
62
- name=self.name,
75
+ original=self.original,
63
76
  ),
64
77
  strides=outer_strides,
65
- name=self.name,
78
+ original=self.original,
66
79
  )
67
80
 
68
81
  def expand(self, shape):
@@ -77,12 +90,21 @@ class Tensor:
77
90
  stride if new_size == -1 else 0
78
91
  for new_size, stride in zip(shape, self.strides)
79
92
  ],
80
- name=self.name,
93
+ original=self.original,
94
+ )
95
+
96
+ def squeeze(self, dim):
97
+ # TODO: Add error handling.
98
+ return type(self)(
99
+ shape=[size for i, size in enumerate(self.shape) if dim != i],
100
+ dtype=self.dtype,
101
+ strides=[stride for i, stride in enumerate(self.strides) if dim != i],
102
+ original=self.original,
81
103
  )
82
104
 
83
105
  def names(self):
84
106
  return (
85
- {self._pointer()}
107
+ {self.original.pointer_string()}
86
108
  | {
87
109
  name
88
110
  for value in itertools.chain(self.shape, self.strides)
@@ -92,35 +114,31 @@ class Tensor:
92
114
  | (self.dtype.names() if isinstance(self.dtype, type(self)) else set())
93
115
  )
94
116
 
95
- def pointers(self, offsets=None):
96
- if offsets is None:
97
- offsets = self.offsets()
98
-
99
- return self._pointer() + offsets
100
-
101
117
  def offsets(self, indices=None):
102
118
  if indices is None:
103
119
  indices = self.indices()
104
120
 
105
- if not isinstance(self.dtype, type(self)):
106
- if len(indices) != self.ndim():
107
- raise IndexError("Incorrect number of indices.")
121
+ offsets = [[] for _ in range(self.original.ndim)]
108
122
 
109
- return sum(
110
- indices[idx]
111
- * self.stride(idx)
112
- * call("arange", 0, self.size(idx))[
113
- tuple(slice(None) if i == idx else None for i in range(self.ndim()))
114
- ]
115
- for idx in range(self.ndim())
116
- )
123
+ curr = self
124
+ start = 0
117
125
 
118
- outer_indices = indices[: self.ndim()]
119
- inner_indices = indices[self.ndim() :]
126
+ while isinstance(curr, type(self)):
127
+ stop = start + curr.ndim
128
+ curr_indices = indices[start:stop]
120
129
 
121
- return sum(
122
- index * stride for index, stride in zip(outer_indices, self.strides)
123
- ) + self.dtype.offsets(inner_indices)
130
+ for index, stride in zip(curr_indices, curr.strides):
131
+ for dim in self._dims_of(stride):
132
+ offsets[dim].append(index * stride)
133
+
134
+ start = stop
135
+ curr = curr.dtype
136
+
137
+ for dim in range(self.original.ndim):
138
+ offsets[dim] = sum(offsets[dim])
139
+ offsets[dim].find_and_replace(Symbol(self.original.strides[dim]), Symbol(1))
140
+
141
+ return offsets
124
142
 
125
143
  def indices(self, index=None):
126
144
  if index is None:
@@ -128,17 +146,22 @@ class Tensor:
128
146
 
129
147
  indices = []
130
148
 
131
- for stride in type(self)(shape=self.shape, name=self.name).strides:
149
+ for stride in type(self)(shape=self.shape, original=self.original).strides:
132
150
  indices.append(index // stride)
133
151
  index %= stride
134
152
 
135
153
  curr = self.dtype
136
- while isinstance(curr, type(self)):
137
- indices.extend(
138
- 0 if curr is not self.inmost() else 1 for _ in range(curr.ndim())
139
- )
154
+
155
+ while isinstance(curr.dtype, type(self)):
156
+ for _ in range(curr.ndim):
157
+ indices.append(0)
158
+
140
159
  curr = curr.dtype
141
160
 
161
+ if isinstance(curr, type(self)):
162
+ for dim in range(curr.ndim):
163
+ indices.append(call("arange", 0, curr.shape[dim]))
164
+
142
165
  return tuple(indices)
143
166
 
144
167
  def inmost(self):
@@ -147,8 +170,14 @@ class Tensor:
147
170
 
148
171
  return self.dtype.inmost()
149
172
 
150
- def ndim(self):
151
- return len(self.shape)
173
+ def pointer_string(self):
174
+ return f"{self.name}_pointer"
175
+
176
+ def size_string(self, dim):
177
+ return f"{self.name}_size_{dim}"
178
+
179
+ def stride_string(self, dim):
180
+ return f"{self.name}_stride_{dim}"
152
181
 
153
182
  def size(self, dim=None):
154
183
  if dim is None:
@@ -162,12 +191,31 @@ class Tensor:
162
191
 
163
192
  return self.strides[dim]
164
193
 
194
+ @property
195
+ def ndim(self):
196
+ return len(self.shape)
197
+
165
198
  @staticmethod
166
- def is_pointer(name):
167
- return name.endswith("_ptr")
199
+ def pointer_pattern():
200
+ return re.compile(rf"({_identifier_pattern_raw_string()})_(pointer)")
168
201
 
169
- def _pointer(self):
170
- return f"{self.name}_ptr"
202
+ @staticmethod
203
+ def size_pattern():
204
+ return re.compile(rf"({_identifier_pattern_raw_string()})_(size)_(.+)")
205
+
206
+ @staticmethod
207
+ def stride_pattern():
208
+ return re.compile(rf"({_identifier_pattern_raw_string()})_(stride)_(.+)")
209
+
210
+ def _dims_of(self, stride):
211
+ dims = set()
212
+ names = stride.names() if isinstance(stride, Symbol) else {stride}
213
+
214
+ for dim, original_stride in enumerate(self.original.strides):
215
+ if str(original_stride) in names:
216
+ dims.add(dim)
217
+
218
+ return dims
171
219
 
172
220
  @staticmethod
173
221
  def _calculate_default_strides(shape):
@@ -177,3 +225,7 @@ class Tensor:
177
225
  strides.append(size * strides[-1])
178
226
 
179
227
  return reversed(strides)
228
+
229
+
230
+ def _identifier_pattern_raw_string():
231
+ return r"[a-zA-Z_][a-zA-Z0-9_]*"
ninetoothed/torchifier.py CHANGED
@@ -1,23 +1,27 @@
1
1
  import ast
2
- import re
2
+
3
+ from ninetoothed.tensor import Tensor
3
4
 
4
5
 
5
6
  class Torchifier(ast.NodeTransformer):
6
7
  def visit_Name(self, node):
7
8
  self.generic_visit(node)
8
9
 
9
- pattern = re.compile(r"([a-zA-Z_][a-zA-Z0-9_]*)_(size|stride)_(.+)")
10
+ source = node.id
11
+
12
+ def repl(match):
13
+ return f"{match.group(1)}"
14
+
15
+ source = Tensor.pointer_pattern().sub(repl, source)
16
+
17
+ def repl(match):
18
+ return f"{match.group(1)}.{match.group(2)}({match.group(3)})"
10
19
 
11
- node.id = node.id.replace("_ptr", "")
20
+ source = Tensor.size_pattern().sub(repl, source)
21
+ source = Tensor.stride_pattern().sub(repl, source)
12
22
 
13
- if re.fullmatch(pattern, node.id):
14
- return ast.parse(
15
- pattern.sub(
16
- lambda match: f"{match.group(1)}.{match.group(2)}({match.group(3)})",
17
- node.id,
18
- ),
19
- mode="eval",
20
- ).body
23
+ if source != node.id:
24
+ return ast.parse(source, mode="eval").body
21
25
 
22
26
  return node
23
27
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: ninetoothed
3
- Version: 0.2.0
3
+ Version: 0.3.0
4
4
  Summary: A domain-specific language based on Triton but providing higher-level abstraction.
5
5
  Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
6
6
  Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
@@ -10,6 +10,7 @@ Classifier: License :: OSI Approved :: Apache Software License
10
10
  Classifier: Operating System :: OS Independent
11
11
  Classifier: Programming Language :: Python :: 3
12
12
  Requires-Python: >=3.10
13
+ Requires-Dist: triton>=3.0.0
13
14
  Description-Content-Type: text/markdown
14
15
 
15
16
  # NineToothed
@@ -26,7 +27,7 @@ We can use `pip` to install `ninetoothed`.
26
27
  pip install ninetoothed
27
28
  ```
28
29
 
29
- After successfully running the above command, `ninetoothed` will be installed. However, to fully utilize its capabilities, you also need to install `triton` and a deep learning framework supported by `ninetoothed`. For trial purposes, we recommend installing `triton` and `torch`.
30
+ After successfully running the above command, `ninetoothed` will be installed. However, to fully utilize its capabilities, you also need to install a deep learning framework supported by `ninetoothed`. For trial purposes, we recommend installing `torch`.
30
31
 
31
32
  ## Usage
32
33
 
@@ -64,16 +65,19 @@ c_tiled = Tensor(2).tile((BLOCK_SIZE_M, BLOCK_SIZE_N))
64
65
  a_tiled = a_tiled.expand((-1, c_tiled.shape[1]))
65
66
  b_tiled = b_tiled.expand((c_tiled.shape[0], -1))
66
67
 
68
+ a_tiled.dtype = a_tiled.dtype.squeeze(0)
69
+ b_tiled.dtype = b_tiled.dtype.squeeze(1)
70
+
67
71
  @ninetoothed.jit
68
72
  def matmul_kernel(a: a_tiled, b: b_tiled, c: c_tiled):
69
73
  accumulator = ninetoothed.language.zeros(
70
74
  c.shape, dtype=ninetoothed.language.float32
71
75
  )
72
- for k in range(a.shape[1]):
73
- accumulator = ninetoothed.language.dot(a[0, k], b[k, 0], accumulator)
76
+ for k in range(a.shape[0]):
77
+ accumulator += ninetoothed.language.dot(a[k], b[k])
74
78
  c = accumulator.to(ninetoothed.language.float16)
75
79
  ```
76
80
 
77
- For matrix multiplication, we also have three tensor parameters, but the tiling method is more complex than vector addition. We denote the three matrices as $A$, $B$, and $C$, where $A$ and $B$ are inputs, and $C$ is the output. Tiling $C$ is simple; we just need to divide it into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_N)` by rows and columns. Once each block computes its result, the entire $C$ is computed. However, how should we tile $A$ and $B$? The answer is to introduce another meta-parameter `BLOCK_SIZE_K`. This way, we can divide $A$ into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_K)` and $B$ into blocks of size `(BLOCK_SIZE_K, BLOCK_SIZE_N)`. However, for matrix multiplication, $A$ and $B$ do not correspond block by block; each row of $A$ needs to correspond to each column of $B$. Therefore, we need to further `tile` $A$ and $B$ by rows and columns, respectively. Up to this point, we have a set of row blocks of $A$ and column blocks of $B$. However, each row block of $A$ must correspond to every column block of $B$. This is where `expand` comes in. We `expand` the row blocks of $A$ along the columns to the number of columns of $C$ and the column blocks of $B$ along the rows to the number of rows of $C$. This way, we successfully tile $A$, $B$, and $C$.
81
+ For matrix multiplication, we also have three tensor parameters, but the tiling method is more complex than vector addition. We denote the three matrices as $A$, $B$, and $C$, where $A$ and $B$ are inputs, and $C$ is the output. Tiling $C$ is simple; we just need to divide it into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_N)` by rows and columns. Once each block computes its result, the entire $C$ is computed. However, how should we tile $A$ and $B$? The answer is to introduce another meta-parameter `BLOCK_SIZE_K`. This way, we can divide $A$ into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_K)` and $B$ into blocks of size `(BLOCK_SIZE_K, BLOCK_SIZE_N)`. However, for matrix multiplication, $A$ and $B$ do not correspond block by block; each row of $A$ needs to correspond to each column of $B$. Therefore, we need to further `tile` $A$ and $B$ by rows and columns, respectively. Up to this point, we have a set of row blocks of $A$ and column blocks of $B$. However, each row block of $A$ must correspond to every column block of $B$. This is where `expand` comes in. We `expand` the row blocks of $A$ along the columns to the number of columns of $C$ and the column blocks of $B$ along the rows to the number of rows of $C$. This way, we successfully tile $A$, $B$, and $C$. In fact, our meta-operations up to this point have already enabled us to write kernel functions. However, we notice that the levels where the row blocks and column blocks reside, which we mentioned earlier, are two-dimensional, and their sizes are of the forms `(1, ...)` and `(..., 1)`. This means that if no other operations are performed, the way we access row blocks and column blocks would have to be `a[0, k]` and `b[k, 0]`. If we want to use `a` to find the range of `k`, we would need to use `a.shape[1]`, but we know that dimensions of size `1` can actually be removed completely. This is why we added two lines of `squeeze`. 
The `dtype` refers to the data type, which in PyTorch can generally be some integer or floating-point type, such as `torch.float32`. However, since meta-operations like `tile` can be performed in NineToothed, `dtype` can also be a `Tensor`. In other words, there is a concept of "tensors that store tensors" in NineToothed. In summary, these two lines perform operations on the tensors stored in the outermost tensor, removing the dimensions of size `1`. This way, when we access the row and column blocks, we can use `a[k]` and `b[k]`, and when finding the range of `k`, we can use `a.shape[0]`.
78
82
 
79
83
  With tiling done, the rest is simple. In the function body, we define an `accumulator` to accumulate intermediate results. We then iterate through the corresponding row blocks of $A$ and column blocks of $B$, multiplying them and accumulating the results in `accumulator`. Finally, we place the `accumulator` in the corresponding block of $C$. Since each block of the parameter tensors undergoes this operation, the multiplication is completed for the entire tensors as well.
@@ -0,0 +1,10 @@
1
+ ninetoothed/__init__.py,sha256=T5UJXlC-wbo8JKPbLUNT65Kccp12xP52WFV5FsugETI,147
2
+ ninetoothed/jit.py,sha256=nhjZRi8_kcjWZX0eOrnxLlzJfVg5vn12f9oi0Er2ABE,15515
3
+ ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
4
+ ninetoothed/symbol.py,sha256=Bd54qcI8KQAX0JRE_wPXycswtdSofhZ6Rr5MtZcv9fo,4665
5
+ ninetoothed/tensor.py,sha256=_DrjOJ-pBvEbSNUvUoYJduLQXmuKgNcqhe4xUDMVoZw,6275
6
+ ninetoothed/torchifier.py,sha256=8M2PDwyFIfVypX6Z-Vt_bGbsCPqxqKnftL0rXeh9bOM,911
7
+ ninetoothed-0.3.0.dist-info/METADATA,sha256=CqdtfdV0eHzSwxJmFpD2IG5d4WTc6RDlpqMZue4Ml2Q,6720
8
+ ninetoothed-0.3.0.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
9
+ ninetoothed-0.3.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
10
+ ninetoothed-0.3.0.dist-info/RECORD,,
@@ -1,10 +0,0 @@
1
- ninetoothed/__init__.py,sha256=T5UJXlC-wbo8JKPbLUNT65Kccp12xP52WFV5FsugETI,147
2
- ninetoothed/jit.py,sha256=hmUzkFZzsiKLgOHbsN0MAr1G1JCiyQ22cFPtmyZ1OyE,12725
3
- ninetoothed/language.py,sha256=cSuTgi5OwmLFy-dy_AHGZzRm18wz01ByHQ2vioP1vTg,437
4
- ninetoothed/symbol.py,sha256=I2Mc9D1w7AYAIQtyAXyDQ-FBqowVZrd-PK-JOt_SpgA,3787
5
- ninetoothed/tensor.py,sha256=RfwYzdYASkr6usJklESm1n8RoxvYjWnPtCjIfipa2fg,5000
6
- ninetoothed/torchifier.py,sha256=JmIVQE8r0zr_RLExsRDOGNsMu0F7v6J_o22aWqlw81k,841
7
- ninetoothed-0.2.0.dist-info/METADATA,sha256=w6qkc2riniG0N4nDUCUkZWF8Eve3j5brBQHIWIEqLXQ,5422
8
- ninetoothed-0.2.0.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
9
- ninetoothed-0.2.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
10
- ninetoothed-0.2.0.dist-info/RECORD,,