tinygrad-0.9.1-py3-none-any.whl → tinygrad-0.9.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/kernel.py +248 -115
- tinygrad/codegen/lowerer.py +215 -0
- tinygrad/codegen/transcendental.py +310 -0
- tinygrad/codegen/uopgraph.py +622 -0
- tinygrad/codegen/uops.py +235 -393
- tinygrad/device.py +428 -69
- tinygrad/dtype.py +18 -4
- tinygrad/engine/graph.py +19 -32
- tinygrad/engine/jit.py +148 -70
- tinygrad/engine/realize.py +127 -51
- tinygrad/engine/schedule.py +259 -216
- tinygrad/engine/search.py +29 -22
- tinygrad/function.py +9 -0
- tinygrad/helpers.py +87 -49
- tinygrad/lazy.py +34 -35
- tinygrad/multi.py +41 -36
- tinygrad/nn/__init__.py +39 -22
- tinygrad/nn/state.py +3 -3
- tinygrad/ops.py +63 -62
- tinygrad/renderer/__init__.py +43 -21
- tinygrad/renderer/assembly.py +104 -106
- tinygrad/renderer/cstyle.py +87 -60
- tinygrad/renderer/llvmir.py +21 -30
- tinygrad/runtime/autogen/amd_gpu.py +25208 -5753
- tinygrad/runtime/autogen/cuda.py +6 -162
- tinygrad/runtime/autogen/kfd.py +32 -0
- tinygrad/runtime/autogen/libc.py +4260 -0
- tinygrad/runtime/autogen/nvrtc.py +579 -0
- tinygrad/runtime/graph/clang.py +2 -2
- tinygrad/runtime/graph/cuda.py +8 -11
- tinygrad/runtime/graph/hcq.py +120 -107
- tinygrad/runtime/graph/metal.py +18 -15
- tinygrad/runtime/ops_amd.py +197 -305
- tinygrad/runtime/ops_clang.py +2 -2
- tinygrad/runtime/ops_cuda.py +36 -94
- tinygrad/runtime/ops_disk.py +3 -7
- tinygrad/runtime/ops_gpu.py +4 -2
- tinygrad/runtime/ops_hip.py +70 -0
- tinygrad/runtime/ops_metal.py +38 -27
- tinygrad/runtime/ops_nv.py +283 -363
- tinygrad/runtime/ops_python.py +26 -30
- tinygrad/runtime/support/compiler_cuda.py +78 -0
- tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +15 -1
- tinygrad/runtime/support/elf.py +38 -0
- tinygrad/shape/shapetracker.py +5 -14
- tinygrad/shape/symbolic.py +4 -8
- tinygrad/shape/view.py +34 -22
- tinygrad/tensor.py +399 -97
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/METADATA +49 -48
- tinygrad-0.9.2.dist-info/RECORD +70 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/linearizer.py +0 -528
- tinygrad-0.9.1.dist-info/RECORD +0 -63
- /tinygrad/runtime/{driver → support}/__init__.py +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/top_level.txt +0 -0
tinygrad/multi.py
CHANGED
````diff
@@ -1,9 +1,9 @@
 from __future__ import annotations
-from typing import Optional, Union, Any, Tuple, List
+from typing import Optional, Union, Any, Tuple, List, Dict
 import functools, itertools, operator
-from tinygrad.helpers import all_same, all_int, dedup,
+from tinygrad.helpers import all_same, all_int, dedup, prod, DEBUG, RING, getenv
 from tinygrad.dtype import DType, ConstType
-from tinygrad.ops import BinaryOps,
+from tinygrad.ops import BinaryOps, MetaOps, UnaryOps, TernaryOps, ReduceOps
 from tinygrad.lazy import LazyBuffer
 from tinygrad.shape.shapetracker import sint
 
@@ -15,7 +15,7 @@ def all_reduce(op: ReduceOps, lbs: List[LazyBuffer]) -> List[LazyBuffer]:
   n_lbs, dim = len(lbs), prod(lbs[0].shape)
   # Ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
   # so just fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
-  use_ring = (RING >= 2 or (n_lbs > 2 and dim > 256_000 and RING >= 1))
+  use_ring = (RING >= 2 or (n_lbs > 2 and dim > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
   if DEBUG >= 2: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{dim} | {lbs[0].dtype}")
   if not use_ring:
     return [functools.reduce(lambda x,y: x.e(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
@@ -42,10 +42,9 @@ def all_reduce(op: ReduceOps, lbs: List[LazyBuffer]) -> List[LazyBuffer]:
   pads = [((s,dim-e),) for s,e in chunks]
   return [functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), [c.pad(pads[i]) for i,c in enumerate(lb_c)]).reshape(lbs[0].shape) for lb_c in chunked]
 
-def to_sharded(lbs:List[LazyBuffer], axis:int) -> List[LazyBuffer]:
-  if DEBUG >= 3 and lbs[0].shape[axis] % len(lbs) != 0: print(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}")
-
-  return [lb.shrink(tuple((0,s) if a != axis else (min(s,sz*i),min(s,sz*(i+1))) for a,s in enumerate(lb.shape))) for i,lb in enumerate(lbs)]
+def to_sharded(lbs:List[LazyBuffer], axis:int, bounds: Tuple[Tuple[int, int], ...]) -> List[LazyBuffer]:
+  if DEBUG >= 3 and lbs[0].shape[axis] % len(lbs) != 0: print(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}, bounds={bounds}")
+  return [lb.shrink(tuple((0,s) if a != axis else bound for a,s in enumerate(lb.shape))) for i, (bound, lb) in enumerate(zip(bounds, lbs))]
 
 class MultiLazyBuffer:
   def __init__(self, lbs:List[LazyBuffer], axis:Optional[int], real:Optional[List[bool]]=None):
@@ -54,11 +53,10 @@ class MultiLazyBuffer:
     self.lbs, self.axis, self.dtype, self.device, self.real = lbs, axis, lbs[0].dtype, tuple(x.device for x in lbs), real or [True]*len(lbs)
     if axis is not None:
       splits = list(itertools.accumulate([lb.shape[axis] for lb in lbs], initial=0))
-      self.bounds =
+      self.bounds = tuple(zip(splits, splits[1:]))
 
   @property
-  def shape(self):
-    return tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape))
+  def shape(self): return tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape))
 
   @property
   def size(self): return sum(x.size for x in self.real_lbs)
@@ -66,13 +64,13 @@ class MultiLazyBuffer:
   @property
   def real_lbs(self): return [lb for lb,r in zip(self.lbs, self.real) if r]
 
-  def __repr__(self):
-    return f"<MLB {self.axis=} {self.real=} {chr(10)}{chr(10).join([f'{x.device} {x.st}' for x in self.lbs])}>"
+  def __repr__(self): return f"<MLB {self.axis=} {self.real=} {chr(10)}{chr(10).join([f'{x.device} {x.st}' for x in self.lbs])}>"
 
   @staticmethod
-  def from_sharded(lb:LazyBuffer, devices:Tuple[str, ...], axis:Optional[int]
-
-
+  def from_sharded(lb:LazyBuffer, devices:Tuple[str, ...], axis:Optional[int], bounds:Optional[Tuple[Tuple[int, int], ...]]):
+    assert (axis is None) == (bounds is None), "must specify bounds iff axis is specified"
+    lbs = [lb] * len(devices)
+    sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(to_sharded(lbs, axis, bounds) if axis is not None and bounds is not None else lbs, devices)]
     return MultiLazyBuffer([lb if lb.is_unrealized_unmasked_const() else lb.contiguous(allow_buffer_view=False) for lb in sharded_lbs], axis)
 
   def copy_to_device(self, device:str) -> LazyBuffer:
@@ -80,7 +78,8 @@ class MultiLazyBuffer:
       # if we already have a copy on the device, return that
       for lb in self.real_lbs:
         if lb.device == device: return lb
-      return self.
+      return self.real_lbs[0].copy_to_device(device)
+    # copy lbs to device, pad to final shape, and sum
     llbs:List[LazyBuffer] = []
     for lb,real,(start,end) in zip(self.lbs, self.real, self.bounds):
       if not real: continue
@@ -90,32 +89,32 @@ class MultiLazyBuffer:
 
   # passthroughs
   def is_realized(self) -> bool: return all(lb.base.realized is not None for lb, r in zip(self.lbs, self.real) if r is True)
-  def cast(self, dtype:DType, bitcast:bool=False
+  def cast(self, dtype:DType, bitcast:bool=False, allow_buffer_view=True):
+    return MultiLazyBuffer([x.cast(dtype, bitcast, allow_buffer_view) for x in self.lbs], self.axis, self.real)
   def const(self, val:ConstType) -> MultiLazyBuffer: return MultiLazyBuffer([x.const(val) for x in self.lbs], self.axis, self.real)
   def assign(self, x:MultiLazyBuffer): return MultiLazyBuffer([s.assign(d) for s,d in zip(self.lbs, x.lbs)], self.axis, self.real)
   def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis, self.real)
 
   # elementwise is simple
-  def e(self, op:Union[
+  def e(self, op:Union[MetaOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:MultiLazyBuffer, arg:Optional[Any]=None) -> MultiLazyBuffer:
     msrcs = (self,)+in_srcs
     assert all(isinstance(x, MultiLazyBuffer) for x in msrcs), f"all buffers must be MultiLazyBuffer {msrcs}"
     assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"
 
     # NOTE: they all have to share an axis, we always choose [-1]
-    axis = axes[-1] if len(axes := dedup([x.axis for x in msrcs if x.axis is not None])) else None
-    srcs = []
+    axis, bounds = axes[-1] if len(axes := dedup([(x.axis, x.bounds) for x in msrcs if x.axis is not None])) else (None, None)
+    srcs:List[List[LazyBuffer]] = []
     not_all_real = any(not all(mlb.real) for mlb in msrcs)
     new_real = [all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])] if not_all_real else self.real
     assert any(new_real), "output contains no real lb"
     for mlb in msrcs:
-      if mlb.axis == axis or not_all_real: srcs.append(mlb.lbs)
-      elif mlb.axis is None and axis is not None: srcs.append(to_sharded(mlb.lbs, axis))
-      else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis))
-
-
-
-
-    return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape))
+      if (mlb.axis == axis and (mlb.axis is None or mlb.bounds == bounds)) or not_all_real: srcs.append(mlb.lbs)
+      elif mlb.axis is None and axis is not None: srcs.append(to_sharded(mlb.lbs, axis, bounds))
+      else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis, bounds))
+    new_real_lbs:Dict[int,LazyBuffer] = {i:lsrcs[0].e(op, *lsrcs[1:], arg=arg) for i,(lsrcs,r) in enumerate(zip(zip(*srcs), new_real)) if r}
+    # NOTE: const dtype should match real
+    real_dtype = next(iter(new_real_lbs.values())).dtype
+    return MultiLazyBuffer([new_real_lbs.get(i, lsrcs[0].const(0).cast(real_dtype)) for i,lsrcs in enumerate(zip(*srcs))], axis, new_real)
 
   def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> MultiLazyBuffer:
     if self.axis is not None and self.axis in axis:
@@ -128,19 +127,21 @@ class MultiLazyBuffer:
 
   # *** movement ops ***
 
+  def _shape_to_single_shard(self, shape:Tuple[sint, ...], lb:LazyBuffer) -> Tuple[sint, ...]:
+    return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape))
+
   def reshape(self, arg:Tuple[sint, ...]):
     if self.axis is None: return MultiLazyBuffer([x.reshape(arg) for x in self.lbs], None, self.real)
+    assert prod(self.shape) == prod(arg), "reshape must maintain prod(shape)"
     arg_acc:List[sint] = list(itertools.accumulate(arg, operator.mul, initial=1))
     # new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
     # todo: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
     new_axis = len(arg_acc) - arg_acc[::-1].index(prod(self.shape[:self.axis])) - 1
-
-
-
-
-
-      s // len(self.real_lbs) for a,s in enumerate(arg))) for x in self.lbs],
-      new_axis, self.real)
+    assert all(prod(lb.shape[self.axis:]) % prod(arg[new_axis + 1:]) == 0 for lb in self.lbs),\
+      f"reshape cannot move items between shards {self.shape} {arg} {self.bounds}"
+    return MultiLazyBuffer([x.reshape(
+      tuple(s if a != new_axis else prod(x.shape[self.axis:]) // prod(arg[new_axis + 1:]) for a, s in enumerate(arg))
+    ) for x in self.lbs], new_axis, self.real)
 
   def pad(self, arg:Tuple[Tuple[sint, sint], ...]):
     assert self.axis is None or arg[self.axis] == (0,0) or not all(self.real), f"padding not supported for {arg=}"
@@ -152,13 +153,16 @@ class MultiLazyBuffer:
         sum(lb.shape[self.axis] for i,lb in enumerate(self.lbs) if i > self.real.index(True))), "can only pad to whole axis"
       return MultiLazyBuffer([x if r else x.const(0) for x,r in zip(self.lbs, self.real)], self.axis)
     return MultiLazyBuffer([x.pad(arg) for x in self.lbs], self.axis, self.real)
+
   def expand(self, arg:Tuple[sint, ...]):
     # NOTE: this assert isn't needed, sharded axis can have dim 1
     assert self.axis is None or arg[self.axis] == self.shape[self.axis], f"expand not supported on sharded axis {arg=}"
    return MultiLazyBuffer([x.expand(self._shape_to_single_shard(arg, x)) for x in self.lbs], self.axis, self.real)
+
   def permute(self, arg:Tuple[int, ...]):
     # all permutes supported!
     return MultiLazyBuffer([x.permute(arg) for x in self.lbs], arg.index(self.axis) if self.axis is not None else None, self.real)
+
   def shrink(self, arg:Tuple[Tuple[sint, sint], ...]):
     assert self.axis is None or arg[self.axis] == (0, self.shape[self.axis]) or arg[self.axis] in self.bounds, f"shrinking not supported for {arg=}"
     if self.axis is not None and arg[self.axis] in self.bounds and arg[self.axis] != (0, self.shape[self.axis]):
@@ -168,6 +172,7 @@ class MultiLazyBuffer:
       return MultiLazyBuffer([lb if i==idx else lb.const(0) for i,lb in enumerate(self.lbs)], self.axis, [i==idx for i in range(len(self.lbs))])
     return MultiLazyBuffer([x.shrink(tuple((0, x.shape[self.axis]) if a == self.axis else s for a,s in enumerate(arg))) for x in self.lbs],
                            self.axis, self.real)
+
   def stride(self, arg:Tuple[int, ...]):
     assert self.axis is None or arg[self.axis] == 1, "flipping not supported on sharded axis"
     return MultiLazyBuffer([x.stride(arg) for x in self.lbs], self.axis, self.real)
````
tinygrad/nn/__init__.py
CHANGED
````diff
@@ -1,12 +1,12 @@
 import math
-from typing import Optional, Union, Tuple
+from typing import Optional, Union, Tuple
 from tinygrad.tensor import Tensor
 from tinygrad.helpers import prod
 from tinygrad.nn import optim, state, datasets # noqa: F401
 
-class
+class BatchNorm:
   """
-  Applies Batch Normalization over a
+  Applies Batch Normalization over a 2D or 3D input.
 
   - Described: https://paperswithcode.com/method/batch-normalization
   - Paper: https://arxiv.org/abs/1502.03167v3
@@ -20,7 +20,7 @@ class BatchNorm2d:
   ```
 
   ```python exec="true" source="above" session="tensor" result="python"
-  norm = nn.
+  norm = nn.BatchNorm(3)
   t = Tensor.rand(2, 3, 4, 4)
   print(t.mean().item(), t.std().item())
   ```
@@ -39,13 +39,14 @@ class BatchNorm2d:
     self.num_batches_tracked = Tensor.zeros(1, requires_grad=False)
 
   def __call__(self, x:Tensor):
+    shape_mask = [1, -1, *([1]*(x.ndim-2))]
     if Tensor.training:
       # This requires two full memory accesses to x
       # https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
       # There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
-      batch_mean = x.mean(axis=(
-      y = (x - batch_mean.detach().reshape(shape=
-      batch_var = (y*y).mean(axis=
+      batch_mean = x.mean(axis=(reduce_axes:=tuple(x for x in range(x.ndim) if x != 1)))
+      y = (x - batch_mean.detach().reshape(shape=shape_mask))  # d(var)/d(mean) = 0
+      batch_var = (y*y).mean(axis=reduce_axes)
       batch_invstd = batch_var.add(self.eps).pow(-0.5)
 
       # NOTE: wow, this is done all throughout training in most PyTorch models
@@ -56,11 +57,10 @@ class BatchNorm2d:
     else:
       batch_mean = self.running_mean
       # NOTE: this can be precomputed for static inference. we expand it here so it fuses
-      batch_invstd = self.running_var.reshape(
-
+      batch_invstd = self.running_var.reshape(shape=shape_mask).expand(x.shape).add(self.eps).rsqrt()
     return x.batchnorm(self.weight, self.bias, batch_mean, batch_invstd)
+BatchNorm2d = BatchNorm3d = BatchNorm
 
-# TODO: these Conv lines are terrible
 def Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
   """
   Applies a 1D convolution over an input signal composed of several input planes.
@@ -98,16 +98,13 @@ class Conv2d:
   def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
     self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
     self.stride, self.padding, self.dilation, self.groups = stride, padding, dilation, groups
-
-
-    self.bias = Tensor.uniform(out_channels, low=-
+    scale = 1 / math.sqrt(in_channels * prod(self.kernel_size))
+    self.weight = Tensor.uniform(out_channels, in_channels//groups, *self.kernel_size, low=-scale, high=scale)
+    self.bias = Tensor.uniform(out_channels, low=-scale, high=scale) if bias else None
 
   def __call__(self, x:Tensor):
     return x.conv2d(self.weight, self.bias, padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
 
-  def initialize_weight(self, out_channels, in_channels, groups):
-    return Tensor.kaiming_uniform(out_channels, in_channels//groups, *self.kernel_size, a=math.sqrt(5))
-
 def ConvTranspose1d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
   """
   Applies a 1D transposed convolution operator over an input signal composed of several input planes.
@@ -144,15 +141,14 @@ class ConvTranspose2d(Conv2d):
   """
   def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
     super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
+    scale = 1 / math.sqrt(in_channels * prod(self.kernel_size))
+    self.weight = Tensor.uniform(in_channels, out_channels//groups, *self.kernel_size, low=-scale, high=scale)
     self.output_padding = output_padding
 
   def __call__(self, x:Tensor):
     return x.conv_transpose2d(self.weight, self.bias, padding=self.padding, output_padding=self.output_padding, stride=self.stride,
                               dilation=self.dilation, groups=self.groups)
 
-  def initialize_weight(self, out_channels, in_channels, groups):
-    return Tensor.kaiming_uniform(in_channels, out_channels//groups, *self.kernel_size, a=math.sqrt(5))
-
 class Linear:
   """
   Applies a linear transformation to the incoming data.
@@ -170,9 +166,8 @@ class Linear:
   ```
   """
   def __init__(self, in_features, out_features, bias=True):
-    # TODO: is this init good? torch inits to uniform(-1/sqrt(in_features), 1/sqrt(in_features))
-    self.weight = Tensor.kaiming_uniform(out_features, in_features, a=math.sqrt(5))
     bound = 1 / math.sqrt(in_features)
+    self.weight = Tensor.uniform(out_features, in_features, low=-bound, high=bound)
     self.bias = Tensor.uniform(out_features, low=-bound, high=bound) if bias else None
 
   def __call__(self, x:Tensor):
@@ -282,6 +277,28 @@ class LayerNorm2d(LayerNorm):
   """
   def __call__(self, x): return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
 
+class RMSNorm:
+  """
+  Applies Root Mean Square Normalization to input.
+
+  - Described: https://paperswithcode.com/method/rmsnorm
+  - Paper: https://arxiv.org/abs/1910.07467
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  norm = nn.RMSNorm(4)
+  t = Tensor.arange(12, dtype=dtypes.float).reshape(3, 4)
+  print(t.numpy())
+  ```
+  ```python exec="true" source="above" session="tensor" result="python"
+  print(norm(t).numpy())
+  ```
+  """
+  def __init__(self, dim, eps=1e-6): self.eps, self.weight = eps, Tensor.ones(dim)
+
+  def _norm(self, x:Tensor): return x * (x.square().mean(-1, keepdim=True) + self.eps).rsqrt()
+
+  def __call__(self, x:Tensor) -> Tensor: return self._norm(x.float()).cast(x.dtype) * self.weight
+
 class Embedding:
   """
   A simple lookup table that stores embeddings of a fixed dictionary and size.
@@ -301,4 +318,4 @@ class Embedding:
     arange_shp, weight_shp, big_shp = (1, 1, self.vocab_sz, 1), (1, 1, self.vocab_sz, self.embed_sz), idx.shape+(self.vocab_sz, self.embed_sz,)
     if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).reshape(arange_shp)
     arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1,)).expand(big_shp), self.weight.reshape(weight_shp).expand(big_shp)
-    return (arange == idx).mul(vals).sum(2)
+    return (arange == idx).mul(vals).sum(2, acc_dtype=vals.dtype)
````
tinygrad/nn/state.py
CHANGED
````diff
@@ -159,8 +159,7 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
       if DEBUG >= 3: print(f"WARNING: this torch load is slow. CLANG to permute {intermediate_shape} with {permute_indexes}")
       assert storage[1] != dtypes.bfloat16, "can't CLANG permute BF16"
       # TODO: find a nice way to support all shapetracker on disktensors
-
-      ret = ret.clang().reshape(intermediate_shape).permute(permute_indexes).realize()
+      ret = ret.clang().reshape(intermediate_shape).permute(permute_indexes)
 
     return ret.reshape(size)
 
@@ -168,7 +167,8 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
     def __setstate__(self, state): self.tensor = state[0]
 
   deserialized_objects: Dict[str, Any] = {}
-  intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16,
+  intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16,
+               "IntStorage": dtypes.int32, "BoolStorage": dtypes.bool,
                "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2, "FloatTensor": None, "Parameter": Parameter}
   whitelist = {"torch", "collections", "numpy", "_codecs"} # NOTE: this is not for security, only speed
   class Dummy: pass
````
tinygrad/ops.py
CHANGED
````diff
@@ -3,7 +3,7 @@ from typing import Union, Tuple, Any, List, Dict, Callable
 import functools, hashlib, math, operator, ctypes, struct
 from enum import Enum, auto
 from dataclasses import dataclass
-from tinygrad.helpers import prod, dedup
+from tinygrad.helpers import prod, dedup, pretty_print
 from tinygrad.dtype import dtypes, DType, ConstType
 from tinygrad.shape.symbolic import Variable, sint
 from tinygrad.shape.shapetracker import ShapeTracker
@@ -18,17 +18,17 @@ class UnaryOps(Enum):
 class BinaryOps(Enum):
   """A + A -> A (elementwise)"""
   ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
-  SHR = auto();
+  SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto() # noqa: E702
 class TernaryOps(Enum):
   """A + A + A -> A (elementwise)"""
   WHERE = auto(); MULACC = auto() # noqa: E702
 class ReduceOps(Enum):
   """A -> B (reduce)"""
-  SUM = auto(); MAX = auto() # noqa: E702
+  SUM = auto(); MAX = auto(); WMMA = auto() # noqa: E702
 class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702
-class
-
-Op = Union[UnaryOps, BinaryOps, ReduceOps,
+class MetaOps(Enum):
+  EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); ASSIGN = auto(); VIEW = auto(); KERNEL = auto(); EXT = auto() # noqa: E702
+Op = Union[UnaryOps, BinaryOps, ReduceOps, MetaOps, TernaryOps, BufferOps]
 
 # do not preserve f(0) = 0
 UNSAFE_PAD_OPS = {UnaryOps.RECIP, UnaryOps.LOG2, UnaryOps.EXP2, BinaryOps.IDIV}
@@ -45,6 +45,12 @@ class ConstBuffer:
   dtype: DType
   st: ShapeTracker
 
+@dataclass(frozen=True)
+class KernelInfo:
+  local_dims: int = 0  # number of local dimensions (this is remapping RANGE to SPECIAL)
+  upcasted: int = 0  # count that are upcasted (this is remapping RANGE to EXPAND)
+  dont_use_locals: bool = False  # don't use local indexing
+
 @dataclass(frozen=True, eq=False)
 class LazyOp:
   op: Op
@@ -57,13 +63,17 @@ class LazyOp:
     ret = context[key] = all(a.cached_compare(b, context) for a,b in zip(self.src, x.src))
     return ret
   def __eq__(self, x): return self.cached_compare(x, context={})
-  def __repr__(self): return f
+  def __repr__(self:LazyOp): return pretty_print(self, lambda x: f'LazyOp({x.op}, arg={x.arg}, src=(%s))')
   @functools.cached_property
   def dtype(self) -> DType:
     if self.op in BufferOps: return self.arg.dtype
+    if self.op is ReduceOps.WMMA: return self.arg[3]  # WMMA can change the type
     if self.op in [UnaryOps.CAST, UnaryOps.BITCAST]: return self.arg
     return dtypes.bool if self.op in {BinaryOps.CMPLT, BinaryOps.CMPNE} else self.src[-1].dtype
-
+  @functools.cached_property
+  def full_shape(self) -> Tuple[sint, ...]:
+    if len(self.src) == 0 and self.op in BufferOps: return self.arg.st.shape
+    return tuple(max(x) for x in zip(*[x.full_shape for x in self.src]))
   @functools.cached_property
   def key(self) -> bytes:
     return hashlib.sha256(functools.reduce(lambda x,y: x+y, [s.key for s in self.src], str((self.op, self.arg)).encode())).digest()
@@ -77,35 +87,16 @@ class LazyOp:
     const_vars = [x.arg.val for x in self.lazyops if x.op is BufferOps.CONST and isinstance(x.arg.val, Variable)]
     return sorted(set.union(*extract_vars, set(const_vars)), key=lambda v: v.expr)
 
-#
-
-
-
-
-
-
-  @
-  def
-
-    self.flops, ret = 0, self.flops
-    return ret
-
-InterpretedFlopCounter: Dict[Op, Callable] = {
-  BufferOps.LOAD: lambda arg: FlopCounter(arg.st.shape, 0, {arg.idx: arg.dtype.itemsize * arg.st.real_size()}),
-  BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, 0, {}),
-  BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize * arg.st.real_size()}),
-  UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, self.consume_flops(), self.mem), # cast uses no flops
-  UnaryOps.BITCAST: lambda self,arg: FlopCounter(self.shape, self.consume_flops(), self.mem), # bitcast uses no flops
-  **{op:lambda self: FlopCounter(self.shape, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op not in {UnaryOps.CAST, UnaryOps.BITCAST}}, # noqa: E501
-  **{op:lambda self,y: FlopCounter(self.shape, self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, # noqa: E501
-  **{op:lambda self,axis: FlopCounter(tuple(1 if i in axis else s for i,s in enumerate(self.shape)), self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps}, # noqa: E501
-  TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})} # noqa: E501
-
-@functools.lru_cache(None)
-def get_lazyop_info(ast:LazyOp) -> FlopCounter:
-  @functools.lru_cache(None) # NOTE: this cache needs to be recreated for new ASTs
-  def run_ast(ast): return InterpretedFlopCounter[ast.op](*([run_ast(x) for x in ast.src]+([ast.arg] if ast.arg is not None else [])))
-  return run_ast(ast)
+  # TODO: support non-lazyop
+  def __add__(self, x:LazyOp): return LazyOp(BinaryOps.ADD, (self, x))
+  def __sub__(self, x:LazyOp): return LazyOp(BinaryOps.ADD, (self, -x))
+  def __mul__(self, x:LazyOp): return LazyOp(BinaryOps.MUL, (self, x))
+  def ne(self, x:LazyOp): return LazyOp(BinaryOps.CMPNE, (self, x))
+  def eq(self, x:LazyOp): return -self.ne(x)
+  def __neg__(self): return LazyOp(UnaryOps.NEG, (self,))
+  @staticmethod
+  def const(val, dtype:DType, shape:Tuple[sint, ...]):
+    return LazyOp(BufferOps.CONST, (), ConstBuffer(val, dtype, ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape)))
 
 # **************** ops in python ****************
 
@@ -115,18 +106,15 @@ def hook_overflow(dv, fxn):
     except OverflowError: return dv
   return wfxn
 
-python_alu = {
-  UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan,
-  UnaryOps.
-  UnaryOps.
-
-  UnaryOps.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
-  UnaryOps.NEG: lambda x: (not x) if isinstance(x, bool) else -x,
-  BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift,
-  BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add,
+python_alu: Dict[Op, Callable] = {
+  UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, UnaryOps.EXP2: hook_overflow(math.inf, lambda x: 2**x),
+  UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
+  UnaryOps.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, UnaryOps.NEG: lambda x: (not x) if isinstance(x, bool) else -x,
+  BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift, BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add,
   BinaryOps.XOR: operator.xor, BinaryOps.MAX: max, BinaryOps.CMPNE: operator.ne, BinaryOps.CMPLT: operator.lt,
-  BinaryOps.
-
+  BinaryOps.OR: operator.or_, BinaryOps.AND: operator.and_,
+  BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else x*math.inf,
+  TernaryOps.MULACC: lambda x,y,z: (x*y)+z, TernaryOps.WHERE: lambda x,y,z: y if x else z}
 
 def truncate_fp16(x):
   try:
@@ -140,30 +128,43 @@ truncate: Dict[DType, Callable] = {dtypes.bool: bool,
   dtypes.float16: truncate_fp16, dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
   dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
   dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
-  dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value,
-
+  dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value, dtypes.int32: lambda x: ctypes.c_int32(x).value \
+      if isinstance(x,int) else x, dtypes.int64: lambda x: ctypes.c_int64(x).value}
 
 def exec_alu(op:Op, dtype:DType, operands): return truncate.get(dtype, lambda x: x)(python_alu[op](*operands))
 
+def reduce_st(st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[sint, ...]: return tuple(1 if i in axis else s for i,s in enumerate(st.shape))
+
 # the living definition of LazyOps
-def verify_lazyop(
+def verify_lazyop(ast:LazyOp) -> Dict[LazyOp, ShapeTracker]:
+  assert ast.op is MetaOps.KERNEL, "must be SINK"
   sts: Dict[LazyOp, ShapeTracker] = {}
-  def
+  def assert_valid(op:LazyOp, st:ShapeTracker):
     if op in sts: return
-
+    # restore globals from the two stage reduce
+    if op.op is BufferOps.LOAD and op.arg.idx < 0:
+      assert_valid(local_reduce:=op.src[0].src[0], op.arg.st)
+      return sts.setdefault(op, sts[local_reduce])
+    for x in op.src: assert_valid(x, st)
     # only reduceop is allowed to change shape, limited to turning n to 1
     if op.op in ReduceOps:
-
-      assert
-      st = ShapeTracker.from_shape(
+      axis = op.arg[-1] if op.op is ReduceOps.WMMA else op.arg
+      assert isinstance(axis, tuple) and all(isinstance(i, int) for i in axis), f"reduceop must have axis {op.arg}"
+      st = ShapeTracker.from_shape(reduce_st(sts[op.src[0]], axis))
     else:
       # movementops are pushed to the edges with LOAD
-
-
-      for x in op.src:
+      # elementwise inherits shape
+      st = op.arg.st if op.op in BufferOps else sts[op.src[0]]
+    for x in op.src:
+      if sts[x].shape != st.shape:
+        if prod(sts[x].shape) == prod(st.shape): raise AssertionError(f"found implicit reshape {x.op} {op.op} {sts[x].shape} != {st.shape}")
+        raise AssertionError(f"found implicit expand {x.op} {sts[x].shape} != {op.op} {st.shape} {prod(sts[x].shape)} != {prod(st.shape)}")
     sts[op] = st
-  for i, out in enumerate(ast):
+  for i, out in enumerate(ast.src):
    assert out.arg.idx == i, f"unexpected output buffer idx {out.arg.idx} != {i}"
    assert out.op is BufferOps.STORE, f"kernels must have stores as the output, got {out.op}"
-    assert out.arg.st.size == ast[-1].arg.st.size, f"outputs must have the same size, got {out.arg.st.size}"
-
+    assert out.arg.st.size == ast.src[-1].arg.st.size, f"outputs must have the same size, got {out.arg.st.size}"
+    assert_valid(out, out.arg.st)
+  shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])]
+  assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}"
+  return sts
````
tinygrad/renderer/__init__.py
CHANGED
````diff
@@ -1,8 +1,9 @@
-from typing import Optional, List, Tuple, Dict
+from typing import Optional, List, Tuple, Dict, Callable, Any
 import functools
-from dataclasses import dataclass
-from tinygrad.helpers import
-from tinygrad.codegen.uops import
+from dataclasses import dataclass, field
+from tinygrad.helpers import to_function_name, dedup
+from tinygrad.codegen.uops import UOps, UOp, flops_mem
+from tinygrad.ops import Op
 from tinygrad.shape.symbolic import sym_infer, sint, Variable
 from tinygrad.dtype import DType
 
@@ -12,30 +13,53 @@ class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x
   dtype_in: DType # dtype for A and B
   dtype_out: DType # dtype for C and D
   threads: List[Tuple[int,int]] # list of (TC dim,amt) that construct the warp thread structure
-  thread_local_aliases: List[List[List[int]]] # a list of [threads_1, ..., threads_n, upcast_1(unrolled), upcast_2(upcast)] defining the alias (-1 is upcast, 1-n is warp threads) for each TC dim # noqa: E501
-  thread_local_sizes: List[List[int]] # in each thread, the number of elements stored in registers for each TC dim
   def __str__(self): return "_".join(["WMMA"] + list(map(str, self.dims)) + [self.dtype_in.name, self.dtype_out.name])
-  def num_upcasts(self): return len(self.thread_local_aliases[0]) - len(self.threads)
 
-@dataclass
+@dataclass
 class Program:
   name:str
   src:str
   dname:str
+  uops:Optional[List[UOp]]=None
+  mem_estimate:sint=0  # TODO: get this from the load/store uops once min/max are good
+
+  # filled in from uops (if we have uops)
   global_size:Optional[List[int]]=None
   local_size:Optional[List[int]]=None
-
-
-
+  vars:List[Variable]=field(default_factory=list)
+  globals:List[int]=field(default_factory=list)
+  outs:List[int]=field(default_factory=list)
+  _ran_post_init:bool=False  # NOTE: this is needed if you call replace on the Program
 
-
-
+  def __post_init__(self):
+    if not self._ran_post_init and self.uops is not None:
+      # single pass through the uops
+      for u in self.uops:
+        if u.op is UOps.DEFINE_VAR: self.vars.append(u.arg)
+        if u.op is UOps.DEFINE_GLOBAL: self.globals.append(u.arg)
+        if u.op is UOps.STORE: self.outs.extend([x.arg for x in u.src[0].sparents if x.op is UOps.DEFINE_GLOBAL])
+        if u.op is UOps.SPECIAL:
+          # NOTE: you have to set local_size and global_size to the base [1,1,1] outside this
+          if u.arg[0][0] == 'i': self.local_size = None
+          if u.arg[0][0] == 'l':
+            assert self.local_size is not None
+            self.local_size[int(u.arg[0][-1])] = u.arg[1]
+          else:
+            assert self.global_size is not None
+            self.global_size[int(u.arg[0][-1])] = u.arg[1]
+      self.vars = sorted(self.vars, key=lambda v: v.expr)
+      self.outs = sorted(dedup(self.outs))
+      self._ran_post_init = True
 
+  @property
+  def op_estimate(self) -> sint: return self._ops_lds[0]
+  @property
+  def lds_estimate(self) -> sint: return self._ops_lds[1]
   @functools.cached_property
-  def
+  def _ops_lds(self) -> Tuple[sint, sint]: return (0,0) if self.uops is None else flops_mem(self.uops, ignore_indexing=True)
 
-  @
-  def outcount(self) -> int: return
+  @property
+  def outcount(self) -> int: return len(self.outs)
 
   @functools.cached_property
   def function_name(self) -> str: return to_function_name(self.name)
@@ -57,9 +81,7 @@ class Renderer:
   local_max: Optional[Tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: UOps.SPECIAL int32 indexes right now
   shared_max: int = 32768
   tensor_cores: List[TensorCore] = []
-
-
-  @functools.cached_property
-  def tc(self): return getenv("TC", 1)
+  extra_matcher: Any = None
+  code_for_op: Dict[Op, Callable] = {}
 
-  def render(self, name:str, uops:
+  def render(self, name:str, uops:List[UOp]) -> str: raise NotImplementedError("needs a renderer")
````