tinygrad 0.9.1__py3-none-any.whl → 0.9.2__py3-none-any.whl

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in that registry.
Files changed (56)
  1. tinygrad/codegen/kernel.py +248 -115
  2. tinygrad/codegen/lowerer.py +215 -0
  3. tinygrad/codegen/transcendental.py +310 -0
  4. tinygrad/codegen/uopgraph.py +622 -0
  5. tinygrad/codegen/uops.py +235 -393
  6. tinygrad/device.py +428 -69
  7. tinygrad/dtype.py +18 -4
  8. tinygrad/engine/graph.py +19 -32
  9. tinygrad/engine/jit.py +148 -70
  10. tinygrad/engine/realize.py +127 -51
  11. tinygrad/engine/schedule.py +259 -216
  12. tinygrad/engine/search.py +29 -22
  13. tinygrad/function.py +9 -0
  14. tinygrad/helpers.py +87 -49
  15. tinygrad/lazy.py +34 -35
  16. tinygrad/multi.py +41 -36
  17. tinygrad/nn/__init__.py +39 -22
  18. tinygrad/nn/state.py +3 -3
  19. tinygrad/ops.py +63 -62
  20. tinygrad/renderer/__init__.py +43 -21
  21. tinygrad/renderer/assembly.py +104 -106
  22. tinygrad/renderer/cstyle.py +87 -60
  23. tinygrad/renderer/llvmir.py +21 -30
  24. tinygrad/runtime/autogen/amd_gpu.py +25208 -5753
  25. tinygrad/runtime/autogen/cuda.py +6 -162
  26. tinygrad/runtime/autogen/kfd.py +32 -0
  27. tinygrad/runtime/autogen/libc.py +4260 -0
  28. tinygrad/runtime/autogen/nvrtc.py +579 -0
  29. tinygrad/runtime/graph/clang.py +2 -2
  30. tinygrad/runtime/graph/cuda.py +8 -11
  31. tinygrad/runtime/graph/hcq.py +120 -107
  32. tinygrad/runtime/graph/metal.py +18 -15
  33. tinygrad/runtime/ops_amd.py +197 -305
  34. tinygrad/runtime/ops_clang.py +2 -2
  35. tinygrad/runtime/ops_cuda.py +36 -94
  36. tinygrad/runtime/ops_disk.py +3 -7
  37. tinygrad/runtime/ops_gpu.py +4 -2
  38. tinygrad/runtime/ops_hip.py +70 -0
  39. tinygrad/runtime/ops_metal.py +38 -27
  40. tinygrad/runtime/ops_nv.py +283 -363
  41. tinygrad/runtime/ops_python.py +26 -30
  42. tinygrad/runtime/support/compiler_cuda.py +78 -0
  43. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +15 -1
  44. tinygrad/runtime/support/elf.py +38 -0
  45. tinygrad/shape/shapetracker.py +5 -14
  46. tinygrad/shape/symbolic.py +4 -8
  47. tinygrad/shape/view.py +34 -22
  48. tinygrad/tensor.py +399 -97
  49. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/METADATA +49 -48
  50. tinygrad-0.9.2.dist-info/RECORD +70 -0
  51. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/WHEEL +1 -1
  52. tinygrad/codegen/linearizer.py +0 -528
  53. tinygrad-0.9.1.dist-info/RECORD +0 -63
  54. /tinygrad/runtime/{driver → support}/__init__.py +0 -0
  55. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/LICENSE +0 -0
  56. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/top_level.txt +0 -0
tinygrad/multi.py CHANGED
@@ -1,9 +1,9 @@
  from __future__ import annotations
- from typing import Optional, Union, Any, Tuple, List
+ from typing import Optional, Union, Any, Tuple, List, Dict
  import functools, itertools, operator
- from tinygrad.helpers import all_same, all_int, dedup, round_up, prod, DEBUG, RING
+ from tinygrad.helpers import all_same, all_int, dedup, prod, DEBUG, RING, getenv
  from tinygrad.dtype import DType, ConstType
- from tinygrad.ops import BinaryOps, LoadOps, UnaryOps, TernaryOps, ReduceOps
+ from tinygrad.ops import BinaryOps, MetaOps, UnaryOps, TernaryOps, ReduceOps
  from tinygrad.lazy import LazyBuffer
  from tinygrad.shape.shapetracker import sint

@@ -15,7 +15,7 @@ def all_reduce(op: ReduceOps, lbs: List[LazyBuffer]) -> List[LazyBuffer]:
  n_lbs, dim = len(lbs), prod(lbs[0].shape)
  # Ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
  # so just fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
- use_ring = (RING >= 2 or (n_lbs > 2 and dim > 256_000 and RING >= 1))
+ use_ring = (RING >= 2 or (n_lbs > 2 and dim > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
  if DEBUG >= 2: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{dim} | {lbs[0].dtype}")
  if not use_ring:
  return [functools.reduce(lambda x,y: x.e(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
@@ -42,10 +42,9 @@ def all_reduce(op: ReduceOps, lbs: List[LazyBuffer]) -> List[LazyBuffer]:
  pads = [((s,dim-e),) for s,e in chunks]
  return [functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), [c.pad(pads[i]) for i,c in enumerate(lb_c)]).reshape(lbs[0].shape) for lb_c in chunked]

- def to_sharded(lbs:List[LazyBuffer], axis:int) -> List[LazyBuffer]:
- if DEBUG >= 3 and lbs[0].shape[axis] % len(lbs) != 0: print(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}")
- sz = round_up(lbs[0].shape[axis], len(lbs)) // len(lbs)
- return [lb.shrink(tuple((0,s) if a != axis else (min(s,sz*i),min(s,sz*(i+1))) for a,s in enumerate(lb.shape))) for i,lb in enumerate(lbs)]
+ def to_sharded(lbs:List[LazyBuffer], axis:int, bounds: Tuple[Tuple[int, int], ...]) -> List[LazyBuffer]:
+ if DEBUG >= 3 and lbs[0].shape[axis] % len(lbs) != 0: print(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}, bounds={bounds}")
+ return [lb.shrink(tuple((0,s) if a != axis else bound for a,s in enumerate(lb.shape))) for i, (bound, lb) in enumerate(zip(bounds, lbs))]

  class MultiLazyBuffer:
  def __init__(self, lbs:List[LazyBuffer], axis:Optional[int], real:Optional[List[bool]]=None):
@@ -54,11 +53,10 @@ class MultiLazyBuffer:
  self.lbs, self.axis, self.dtype, self.device, self.real = lbs, axis, lbs[0].dtype, tuple(x.device for x in lbs), real or [True]*len(lbs)
  if axis is not None:
  splits = list(itertools.accumulate([lb.shape[axis] for lb in lbs], initial=0))
- self.bounds = list(zip(splits, splits[1:]))
+ self.bounds = tuple(zip(splits, splits[1:]))

  @property
- def shape(self):
- return tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape))
+ def shape(self): return tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape))

  @property
  def size(self): return sum(x.size for x in self.real_lbs)
@@ -66,13 +64,13 @@ class MultiLazyBuffer:
  @property
  def real_lbs(self): return [lb for lb,r in zip(self.lbs, self.real) if r]

- def __repr__(self):
- return f"<MLB {self.axis=} {self.real=} {chr(10)}{chr(10).join([f'{x.device} {x.st}' for x in self.lbs])}>"
+ def __repr__(self): return f"<MLB {self.axis=} {self.real=} {chr(10)}{chr(10).join([f'{x.device} {x.st}' for x in self.lbs])}>"

  @staticmethod
- def from_sharded(lb:LazyBuffer, devices:Tuple[str, ...], axis:Optional[int]=None):
- lbs = [lb.contiguous() if lb.base != lb and not lb.is_unrealized_unmasked_const() else lb] * len(devices)
- sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(to_sharded(lbs, axis) if axis is not None else lbs, devices)]
+ def from_sharded(lb:LazyBuffer, devices:Tuple[str, ...], axis:Optional[int], bounds:Optional[Tuple[Tuple[int, int], ...]]):
+ assert (axis is None) == (bounds is None), "must specify bounds iff axis is specified"
+ lbs = [lb] * len(devices)
+ sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(to_sharded(lbs, axis, bounds) if axis is not None and bounds is not None else lbs, devices)]
  return MultiLazyBuffer([lb if lb.is_unrealized_unmasked_const() else lb.contiguous(allow_buffer_view=False) for lb in sharded_lbs], axis)

  def copy_to_device(self, device:str) -> LazyBuffer:
@@ -80,7 +78,8 @@ class MultiLazyBuffer:
  # if we already have a copy on the device, return that
  for lb in self.real_lbs:
  if lb.device == device: return lb
- return self.lbs[self.real.index(True)].copy_to_device(device)
+ return self.real_lbs[0].copy_to_device(device)
+ # copy lbs to device, pad to final shape, and sum
  llbs:List[LazyBuffer] = []
  for lb,real,(start,end) in zip(self.lbs, self.real, self.bounds):
  if not real: continue
@@ -90,32 +89,32 @@ class MultiLazyBuffer:

  # passthroughs
  def is_realized(self) -> bool: return all(lb.base.realized is not None for lb, r in zip(self.lbs, self.real) if r is True)
- def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis, self.real)
+ def cast(self, dtype:DType, bitcast:bool=False, allow_buffer_view=True):
+ return MultiLazyBuffer([x.cast(dtype, bitcast, allow_buffer_view) for x in self.lbs], self.axis, self.real)
  def const(self, val:ConstType) -> MultiLazyBuffer: return MultiLazyBuffer([x.const(val) for x in self.lbs], self.axis, self.real)
  def assign(self, x:MultiLazyBuffer): return MultiLazyBuffer([s.assign(d) for s,d in zip(self.lbs, x.lbs)], self.axis, self.real)
  def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis, self.real)

  # elementwise is simple
- def e(self, op:Union[LoadOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:MultiLazyBuffer, arg:Optional[Any]=None) -> MultiLazyBuffer:
+ def e(self, op:Union[MetaOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:MultiLazyBuffer, arg:Optional[Any]=None) -> MultiLazyBuffer:
  msrcs = (self,)+in_srcs
  assert all(isinstance(x, MultiLazyBuffer) for x in msrcs), f"all buffers must be MultiLazyBuffer {msrcs}"
  assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"

  # NOTE: they all have to share an axis, we always choose [-1]
- axis = axes[-1] if len(axes := dedup([x.axis for x in msrcs if x.axis is not None])) else None
- srcs = []
+ axis, bounds = axes[-1] if len(axes := dedup([(x.axis, x.bounds) for x in msrcs if x.axis is not None])) else (None, None)
+ srcs:List[List[LazyBuffer]] = []
  not_all_real = any(not all(mlb.real) for mlb in msrcs)
  new_real = [all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])] if not_all_real else self.real
  assert any(new_real), "output contains no real lb"
  for mlb in msrcs:
- if mlb.axis == axis or not_all_real: srcs.append(mlb.lbs)
- elif mlb.axis is None and axis is not None: srcs.append(to_sharded(mlb.lbs, axis))
- else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis))
- # NOTE: lsrcs[-1].const(0) is correct for where
- return MultiLazyBuffer([lsrcs[0].e(op, *lsrcs[1:], arg=arg) if r else lsrcs[-1].const(0) for lsrcs,r in zip(zip(*srcs),new_real)], axis, new_real)
-
- def _shape_to_single_shard(self, shape:Tuple[sint, ...], lb:LazyBuffer) -> Tuple[sint, ...]:
- return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape))
+ if (mlb.axis == axis and (mlb.axis is None or mlb.bounds == bounds)) or not_all_real: srcs.append(mlb.lbs)
+ elif mlb.axis is None and axis is not None: srcs.append(to_sharded(mlb.lbs, axis, bounds))
+ else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis, bounds))
+ new_real_lbs:Dict[int,LazyBuffer] = {i:lsrcs[0].e(op, *lsrcs[1:], arg=arg) for i,(lsrcs,r) in enumerate(zip(zip(*srcs), new_real)) if r}
+ # NOTE: const dtype should match real
+ real_dtype = next(iter(new_real_lbs.values())).dtype
+ return MultiLazyBuffer([new_real_lbs.get(i, lsrcs[0].const(0).cast(real_dtype)) for i,lsrcs in enumerate(zip(*srcs))], axis, new_real)

  def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> MultiLazyBuffer:
  if self.axis is not None and self.axis in axis:
@@ -128,19 +127,21 @@ class MultiLazyBuffer:

  # *** movement ops ***

+ def _shape_to_single_shard(self, shape:Tuple[sint, ...], lb:LazyBuffer) -> Tuple[sint, ...]:
+ return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape))
+
  def reshape(self, arg:Tuple[sint, ...]):
  if self.axis is None: return MultiLazyBuffer([x.reshape(arg) for x in self.lbs], None, self.real)
+ assert prod(self.shape) == prod(arg), "reshape must maintain prod(shape)"
  arg_acc:List[sint] = list(itertools.accumulate(arg, operator.mul, initial=1))
  # new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
  # todo: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
  new_axis = len(arg_acc) - arg_acc[::-1].index(prod(self.shape[:self.axis])) - 1
- if arg[new_axis] != self.shape[self.axis]:
- assert self.shape[self.axis] % len(self.real_lbs) == 0, f"cannot reshape on-axis for uneven shard {self.axis} {self.shape} {len(self.real_lbs)}"
- assert arg[new_axis] % len(self.real_lbs) == 0, f"new on-axis shape must divide evenly between devices {new_axis} {arg} {len(self.real_lbs)}"
- return MultiLazyBuffer([x.reshape(tuple(s if a != new_axis else
- x.shape[self.axis] if s == self.shape[self.axis] else
- s // len(self.real_lbs) for a,s in enumerate(arg))) for x in self.lbs],
- new_axis, self.real)
+ assert all(prod(lb.shape[self.axis:]) % prod(arg[new_axis + 1:]) == 0 for lb in self.lbs),\
+ f"reshape cannot move items between shards {self.shape} {arg} {self.bounds}"
+ return MultiLazyBuffer([x.reshape(
+ tuple(s if a != new_axis else prod(x.shape[self.axis:]) // prod(arg[new_axis + 1:]) for a, s in enumerate(arg))
+ ) for x in self.lbs], new_axis, self.real)

  def pad(self, arg:Tuple[Tuple[sint, sint], ...]):
  assert self.axis is None or arg[self.axis] == (0,0) or not all(self.real), f"padding not supported for {arg=}"
@@ -152,13 +153,16 @@ class MultiLazyBuffer:
  sum(lb.shape[self.axis] for i,lb in enumerate(self.lbs) if i > self.real.index(True))), "can only pad to whole axis"
  return MultiLazyBuffer([x if r else x.const(0) for x,r in zip(self.lbs, self.real)], self.axis)
  return MultiLazyBuffer([x.pad(arg) for x in self.lbs], self.axis, self.real)
+
  def expand(self, arg:Tuple[sint, ...]):
  # NOTE: this assert isn't needed, sharded axis can have dim 1
  assert self.axis is None or arg[self.axis] == self.shape[self.axis], f"expand not supported on sharded axis {arg=}"
  return MultiLazyBuffer([x.expand(self._shape_to_single_shard(arg, x)) for x in self.lbs], self.axis, self.real)
+
  def permute(self, arg:Tuple[int, ...]):
  # all permutes supported!
  return MultiLazyBuffer([x.permute(arg) for x in self.lbs], arg.index(self.axis) if self.axis is not None else None, self.real)
+
  def shrink(self, arg:Tuple[Tuple[sint, sint], ...]):
  assert self.axis is None or arg[self.axis] == (0, self.shape[self.axis]) or arg[self.axis] in self.bounds, f"shrinking not supported for {arg=}"
  if self.axis is not None and arg[self.axis] in self.bounds and arg[self.axis] != (0, self.shape[self.axis]):
@@ -168,6 +172,7 @@ class MultiLazyBuffer:
  return MultiLazyBuffer([lb if i==idx else lb.const(0) for i,lb in enumerate(self.lbs)], self.axis, [i==idx for i in range(len(self.lbs))])
  return MultiLazyBuffer([x.shrink(tuple((0, x.shape[self.axis]) if a == self.axis else s for a,s in enumerate(arg))) for x in self.lbs],
  self.axis, self.real)
+
  def stride(self, arg:Tuple[int, ...]):
  assert self.axis is None or arg[self.axis] == 1, "flipping not supported on sharded axis"
  return MultiLazyBuffer([x.stride(arg) for x in self.lbs], self.axis, self.real)
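A note on the sharding changes above: `to_sharded` now takes an explicit `bounds` tuple (one `(start, end)` slice per device) instead of deriving equal-sized chunks with `round_up`, and the 256k ring-allreduce element threshold can be overridden through the `RING_ALLREDUCE_THRESHOLD` environment variable. A minimal sketch of how the bounds relate to per-shard sizes, mirroring the `splits`/`bounds` logic in `MultiLazyBuffer.__init__` (the shard sizes below are made up):

```python
import itertools

# hypothetical uneven shard: 8 rows split over 3 devices
shard_sizes = [3, 3, 2]
splits = list(itertools.accumulate(shard_sizes, initial=0))  # [0, 3, 6, 8]
bounds = tuple(zip(splits, splits[1:]))                      # ((0, 3), (3, 6), (6, 8))
print(bounds)
# each (start, end) pair is the slice of the full axis owned by one device,
# which is what to_sharded now uses directly in its shrink
```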
tinygrad/nn/__init__.py CHANGED
@@ -1,12 +1,12 @@
  import math
- from typing import Optional, Union, Tuple, cast
+ from typing import Optional, Union, Tuple
  from tinygrad.tensor import Tensor
  from tinygrad.helpers import prod
  from tinygrad.nn import optim, state, datasets # noqa: F401

- class BatchNorm2d:
+ class BatchNorm:
  """
- Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs with additional channel dimension).
+ Applies Batch Normalization over a 2D or 3D input.

  - Described: https://paperswithcode.com/method/batch-normalization
  - Paper: https://arxiv.org/abs/1502.03167v3
@@ -20,7 +20,7 @@ class BatchNorm2d:
  ```

  ```python exec="true" source="above" session="tensor" result="python"
- norm = nn.BatchNorm2d(3)
+ norm = nn.BatchNorm(3)
  t = Tensor.rand(2, 3, 4, 4)
  print(t.mean().item(), t.std().item())
  ```
@@ -39,13 +39,14 @@ class BatchNorm2d:
  self.num_batches_tracked = Tensor.zeros(1, requires_grad=False)

  def __call__(self, x:Tensor):
+ shape_mask = [1, -1, *([1]*(x.ndim-2))]
  if Tensor.training:
  # This requires two full memory accesses to x
  # https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
  # There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
- batch_mean = x.mean(axis=(0,2,3))
- y = (x - batch_mean.detach().reshape(shape=[1, -1, 1, 1])) # d(var)/d(mean) = 0
- batch_var = (y*y).mean(axis=(0,2,3))
+ batch_mean = x.mean(axis=(reduce_axes:=tuple(x for x in range(x.ndim) if x != 1)))
+ y = (x - batch_mean.detach().reshape(shape=shape_mask)) # d(var)/d(mean) = 0
+ batch_var = (y*y).mean(axis=reduce_axes)
  batch_invstd = batch_var.add(self.eps).pow(-0.5)

  # NOTE: wow, this is done all throughout training in most PyTorch models
@@ -56,11 +57,10 @@ class BatchNorm2d:
  else:
  batch_mean = self.running_mean
  # NOTE: this can be precomputed for static inference. we expand it here so it fuses
- batch_invstd = self.running_var.reshape(1, -1, 1, 1).expand(x.shape).add(self.eps).rsqrt()
-
+ batch_invstd = self.running_var.reshape(shape=shape_mask).expand(x.shape).add(self.eps).rsqrt()
  return x.batchnorm(self.weight, self.bias, batch_mean, batch_invstd)
+ BatchNorm2d = BatchNorm3d = BatchNorm

- # TODO: these Conv lines are terrible
  def Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
  """
  Applies a 1D convolution over an input signal composed of several input planes.
@@ -98,16 +98,13 @@ class Conv2d:
  def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
  self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
  self.stride, self.padding, self.dilation, self.groups = stride, padding, dilation, groups
- self.weight = self.initialize_weight(out_channels, in_channels, groups)
- bound = 1 / math.sqrt(cast(int, prod(self.weight.shape[1:]))) # weight shape is always ints but mypy cannot tell
- self.bias = Tensor.uniform(out_channels, low=-bound, high=bound) if bias else None
+ scale = 1 / math.sqrt(in_channels * prod(self.kernel_size))
+ self.weight = Tensor.uniform(out_channels, in_channels//groups, *self.kernel_size, low=-scale, high=scale)
+ self.bias = Tensor.uniform(out_channels, low=-scale, high=scale) if bias else None

  def __call__(self, x:Tensor):
  return x.conv2d(self.weight, self.bias, padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)

- def initialize_weight(self, out_channels, in_channels, groups):
- return Tensor.kaiming_uniform(out_channels, in_channels//groups, *self.kernel_size, a=math.sqrt(5))
-
  def ConvTranspose1d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
  """
  Applies a 1D transposed convolution operator over an input signal composed of several input planes.
@@ -144,15 +141,14 @@ class ConvTranspose2d(Conv2d):
  """
  def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
  super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
+ scale = 1 / math.sqrt(in_channels * prod(self.kernel_size))
+ self.weight = Tensor.uniform(in_channels, out_channels//groups, *self.kernel_size, low=-scale, high=scale)
  self.output_padding = output_padding

  def __call__(self, x:Tensor):
  return x.conv_transpose2d(self.weight, self.bias, padding=self.padding, output_padding=self.output_padding, stride=self.stride,
  dilation=self.dilation, groups=self.groups)

- def initialize_weight(self, out_channels, in_channels, groups):
- return Tensor.kaiming_uniform(in_channels, out_channels//groups, *self.kernel_size, a=math.sqrt(5))
-
  class Linear:
  """
  Applies a linear transformation to the incoming data.
@@ -170,9 +166,8 @@ class Linear:
  ```
  """
  def __init__(self, in_features, out_features, bias=True):
- # TODO: is this init good? torch inits to uniform(-1/sqrt(in_features), 1/sqrt(in_features))
- self.weight = Tensor.kaiming_uniform(out_features, in_features, a=math.sqrt(5))
  bound = 1 / math.sqrt(in_features)
+ self.weight = Tensor.uniform(out_features, in_features, low=-bound, high=bound)
  self.bias = Tensor.uniform(out_features, low=-bound, high=bound) if bias else None

  def __call__(self, x:Tensor):
@@ -282,6 +277,28 @@ class LayerNorm2d(LayerNorm):
  """
  def __call__(self, x): return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

+ class RMSNorm:
+ """
+ Applies Root Mean Square Normalization to input.
+
+ - Described: https://paperswithcode.com/method/rmsnorm
+ - Paper: https://arxiv.org/abs/1910.07467
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ norm = nn.RMSNorm(4)
+ t = Tensor.arange(12, dtype=dtypes.float).reshape(3, 4)
+ print(t.numpy())
+ ```
+ ```python exec="true" source="above" session="tensor" result="python"
+ print(norm(t).numpy())
+ ```
+ """
+ def __init__(self, dim, eps=1e-6): self.eps, self.weight = eps, Tensor.ones(dim)
+
+ def _norm(self, x:Tensor): return x * (x.square().mean(-1, keepdim=True) + self.eps).rsqrt()
+
+ def __call__(self, x:Tensor) -> Tensor: return self._norm(x.float()).cast(x.dtype) * self.weight
+
  class Embedding:
  """
  A simple lookup table that stores embeddings of a fixed dictionary and size.
@@ -301,4 +318,4 @@ class Embedding:
  arange_shp, weight_shp, big_shp = (1, 1, self.vocab_sz, 1), (1, 1, self.vocab_sz, self.embed_sz), idx.shape+(self.vocab_sz, self.embed_sz,)
  if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).reshape(arange_shp)
  arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1,)).expand(big_shp), self.weight.reshape(weight_shp).expand(big_shp)
- return (arange == idx).mul(vals).sum(2)
+ return (arange == idx).mul(vals).sum(2, acc_dtype=vals.dtype)
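The normalization changes above are easy to exercise; a hedged sketch of the new surface (shapes are arbitrary, and `BatchNorm2d`/`BatchNorm3d` remain available as aliases):

```python
from tinygrad import Tensor, nn

# BatchNorm now derives its reduce axes and reshape mask from x.ndim,
# so the same class handles the old 4D BatchNorm2d case and 3D inputs
bn = nn.BatchNorm(8)
print(bn(Tensor.rand(4, 8, 16)).shape)      # (4, 8, 16)
print(bn(Tensor.rand(4, 8, 16, 16)).shape)  # (4, 8, 16, 16)

# RMSNorm is new in this release: scale by the inverse RMS of the last axis,
# then by a learned weight
rms = nn.RMSNorm(8)
print(rms(Tensor.rand(4, 8)).shape)         # (4, 8)
```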
tinygrad/nn/state.py CHANGED
@@ -159,8 +159,7 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
  if DEBUG >= 3: print(f"WARNING: this torch load is slow. CLANG to permute {intermediate_shape} with {permute_indexes}")
  assert storage[1] != dtypes.bfloat16, "can't CLANG permute BF16"
  # TODO: find a nice way to support all shapetracker on disktensors
- # TODO: BUG: a ".realize()" is needed here for 'GPU=1 python3 test/models/test_efficientnet.py TestEfficientNet.test_car'
- ret = ret.clang().reshape(intermediate_shape).permute(permute_indexes).realize()
+ ret = ret.clang().reshape(intermediate_shape).permute(permute_indexes)

  return ret.reshape(size)

@@ -168,7 +167,8 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
  def __setstate__(self, state): self.tensor = state[0]

  deserialized_objects: Dict[str, Any] = {}
- intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16, "IntStorage": dtypes.int32,
+ intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16,
+ "IntStorage": dtypes.int32, "BoolStorage": dtypes.bool,
  "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2, "FloatTensor": None, "Parameter": Parameter}
  whitelist = {"torch", "collections", "numpy", "_codecs"} # NOTE: this is not for security, only speed
  class Dummy: pass
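The `intercept` change above adds a `BoolStorage` mapping, so torch checkpoints containing bool tensors now deserialize to `dtypes.bool`. A hedged round-trip sketch (assumes PyTorch is installed; the file path is arbitrary):

```python
import torch
from tinygrad.nn.state import torch_load

# write a checkpoint with a bool tensor using torch, then read it back with tinygrad
torch.save({"mask": torch.tensor([True, False, True])}, "/tmp/ckpt.pt")
state = torch_load("/tmp/ckpt.pt")
print(state["mask"].dtype, state["mask"].to("CLANG").numpy())
```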
tinygrad/ops.py CHANGED
@@ -3,7 +3,7 @@ from typing import Union, Tuple, Any, List, Dict, Callable
  import functools, hashlib, math, operator, ctypes, struct
  from enum import Enum, auto
  from dataclasses import dataclass
- from tinygrad.helpers import prod, dedup
+ from tinygrad.helpers import prod, dedup, pretty_print
  from tinygrad.dtype import dtypes, DType, ConstType
  from tinygrad.shape.symbolic import Variable, sint
  from tinygrad.shape.shapetracker import ShapeTracker
@@ -18,17 +18,17 @@ class UnaryOps(Enum):
  class BinaryOps(Enum):
  """A + A -> A (elementwise)"""
  ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
- SHR = auto(); SHL = auto() # noqa: E702
+ SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto() # noqa: E702
  class TernaryOps(Enum):
  """A + A + A -> A (elementwise)"""
  WHERE = auto(); MULACC = auto() # noqa: E702
  class ReduceOps(Enum):
  """A -> B (reduce)"""
- SUM = auto(); MAX = auto() # noqa: E702
+ SUM = auto(); MAX = auto(); WMMA = auto() # noqa: E702
  class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702
- class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); ASSIGN = auto(); VIEW = auto() # noqa: E702
-
- Op = Union[UnaryOps, BinaryOps, ReduceOps, LoadOps, TernaryOps, BufferOps]
+ class MetaOps(Enum):
+ EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); ASSIGN = auto(); VIEW = auto(); KERNEL = auto(); EXT = auto() # noqa: E702
+ Op = Union[UnaryOps, BinaryOps, ReduceOps, MetaOps, TernaryOps, BufferOps]

  # do not preserve f(0) = 0
  UNSAFE_PAD_OPS = {UnaryOps.RECIP, UnaryOps.LOG2, UnaryOps.EXP2, BinaryOps.IDIV}
@@ -45,6 +45,12 @@ class ConstBuffer:
  dtype: DType
  st: ShapeTracker

+ @dataclass(frozen=True)
+ class KernelInfo:
+ local_dims: int = 0 # number of local dimensions (this is remapping RANGE to SPECIAL)
+ upcasted: int = 0 # count that are upcasted (this is remapping RANGE to EXPAND)
+ dont_use_locals: bool = False # don't use local indexing
+
  @dataclass(frozen=True, eq=False)
  class LazyOp:
  op: Op
@@ -57,13 +63,17 @@ class LazyOp:
  ret = context[key] = all(a.cached_compare(b, context) for a,b in zip(self.src, x.src))
  return ret
  def __eq__(self, x): return self.cached_compare(x, context={})
- def __repr__(self): return f"LazyOp(op={self.op}, src={self.src}, arg={self.arg})"
+ def __repr__(self:LazyOp): return pretty_print(self, lambda x: f'LazyOp({x.op}, arg={x.arg}, src=(%s))')
  @functools.cached_property
  def dtype(self) -> DType:
  if self.op in BufferOps: return self.arg.dtype
+ if self.op is ReduceOps.WMMA: return self.arg[3] # WMMA can change the type
  if self.op in [UnaryOps.CAST, UnaryOps.BITCAST]: return self.arg
  return dtypes.bool if self.op in {BinaryOps.CMPLT, BinaryOps.CMPNE} else self.src[-1].dtype
-
+ @functools.cached_property
+ def full_shape(self) -> Tuple[sint, ...]:
+ if len(self.src) == 0 and self.op in BufferOps: return self.arg.st.shape
+ return tuple(max(x) for x in zip(*[x.full_shape for x in self.src]))
  @functools.cached_property
  def key(self) -> bytes:
  return hashlib.sha256(functools.reduce(lambda x,y: x+y, [s.key for s in self.src], str((self.op, self.arg)).encode())).digest()
@@ -77,35 +87,16 @@ class LazyOp:
  const_vars = [x.arg.val for x in self.lazyops if x.op is BufferOps.CONST and isinstance(x.arg.val, Variable)]
  return sorted(set.union(*extract_vars, set(const_vars)), key=lambda v: v.expr)

- # **************** independent FlopCounter ****************
-
- @dataclass
- class FlopCounter:
- shape: Tuple[int, ...]
- flops: sint
- mem: Dict[int, int]
- @property
- def mem_estimate(self): return sum(self.mem.values())
- def consume_flops(self):
- self.flops, ret = 0, self.flops
- return ret
-
- InterpretedFlopCounter: Dict[Op, Callable] = {
- BufferOps.LOAD: lambda arg: FlopCounter(arg.st.shape, 0, {arg.idx: arg.dtype.itemsize * arg.st.real_size()}),
- BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, 0, {}),
- BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize * arg.st.real_size()}),
- UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, self.consume_flops(), self.mem), # cast uses no flops
- UnaryOps.BITCAST: lambda self,arg: FlopCounter(self.shape, self.consume_flops(), self.mem), # bitcast uses no flops
- **{op:lambda self: FlopCounter(self.shape, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op not in {UnaryOps.CAST, UnaryOps.BITCAST}}, # noqa: E501
- **{op:lambda self,y: FlopCounter(self.shape, self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, # noqa: E501
- **{op:lambda self,axis: FlopCounter(tuple(1 if i in axis else s for i,s in enumerate(self.shape)), self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps}, # noqa: E501
- TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})} # noqa: E501
-
- @functools.lru_cache(None)
- def get_lazyop_info(ast:LazyOp) -> FlopCounter:
- @functools.lru_cache(None) # NOTE: this cache needs to be recreated for new ASTs
- def run_ast(ast): return InterpretedFlopCounter[ast.op](*([run_ast(x) for x in ast.src]+([ast.arg] if ast.arg is not None else [])))
- return run_ast(ast)
+ # TODO: support non-lazyop
+ def __add__(self, x:LazyOp): return LazyOp(BinaryOps.ADD, (self, x))
+ def __sub__(self, x:LazyOp): return LazyOp(BinaryOps.ADD, (self, -x))
+ def __mul__(self, x:LazyOp): return LazyOp(BinaryOps.MUL, (self, x))
+ def ne(self, x:LazyOp): return LazyOp(BinaryOps.CMPNE, (self, x))
+ def eq(self, x:LazyOp): return -self.ne(x)
+ def __neg__(self): return LazyOp(UnaryOps.NEG, (self,))
+ @staticmethod
+ def const(val, dtype:DType, shape:Tuple[sint, ...]):
+ return LazyOp(BufferOps.CONST, (), ConstBuffer(val, dtype, ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape)))

  # **************** ops in python ****************

@@ -115,18 +106,15 @@ def hook_overflow(dv, fxn):
  except OverflowError: return dv
  return wfxn

- python_alu = {
- UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan,
- UnaryOps.EXP2: hook_overflow(math.inf, lambda x: 2**x),
- UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan,
- UnaryOps.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan,
- UnaryOps.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
- UnaryOps.NEG: lambda x: (not x) if isinstance(x, bool) else -x,
- BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift,
- BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add,
+ python_alu: Dict[Op, Callable] = {
+ UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, UnaryOps.EXP2: hook_overflow(math.inf, lambda x: 2**x),
+ UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
+ UnaryOps.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, UnaryOps.NEG: lambda x: (not x) if isinstance(x, bool) else -x,
+ BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift, BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add,
  BinaryOps.XOR: operator.xor, BinaryOps.MAX: max, BinaryOps.CMPNE: operator.ne, BinaryOps.CMPLT: operator.lt,
- BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x, y: int(x/y) if y != 0 else x*math.inf,
- TernaryOps.WHERE: lambda x,y,z: y if x else z}
+ BinaryOps.OR: operator.or_, BinaryOps.AND: operator.and_,
+ BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else x*math.inf,
+ TernaryOps.MULACC: lambda x,y,z: (x*y)+z, TernaryOps.WHERE: lambda x,y,z: y if x else z}

  def truncate_fp16(x):
  try:
@@ -140,30 +128,43 @@ truncate: Dict[DType, Callable] = {dtypes.bool: bool,
  dtypes.float16: truncate_fp16, dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
  dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
  dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
- dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value,
- dtypes.int32: lambda x: ctypes.c_int32(x).value, dtypes.int64: lambda x: ctypes.c_int64(x).value,}
+ dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value, dtypes.int32: lambda x: ctypes.c_int32(x).value \
+ if isinstance(x,int) else x, dtypes.int64: lambda x: ctypes.c_int64(x).value}

  def exec_alu(op:Op, dtype:DType, operands): return truncate.get(dtype, lambda x: x)(python_alu[op](*operands))

+ def reduce_st(st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[sint, ...]: return tuple(1 if i in axis else s for i,s in enumerate(st.shape))
+
  # the living definition of LazyOps
- def verify_lazyop(*ast:LazyOp):
+ def verify_lazyop(ast:LazyOp) -> Dict[LazyOp, ShapeTracker]:
+ assert ast.op is MetaOps.KERNEL, "must be SINK"
  sts: Dict[LazyOp, ShapeTracker] = {}
- def dfs(op:LazyOp, st:ShapeTracker):
+ def assert_valid(op:LazyOp, st:ShapeTracker):
  if op in sts: return
- for x in op.src: dfs(x, st)
+ # restore globals from the two stage reduce
+ if op.op is BufferOps.LOAD and op.arg.idx < 0:
+ assert_valid(local_reduce:=op.src[0].src[0], op.arg.st)
+ return sts.setdefault(op, sts[local_reduce])
+ for x in op.src: assert_valid(x, st)
  # only reduceop is allowed to change shape, limited to turning n to 1
  if op.op in ReduceOps:
- expected_shape = tuple(1 if i in op.arg else s for i,s in enumerate(sts[op.src[0]].shape))
- assert st.shape == expected_shape, f"unexpected reduceop shape {st.shape} != {expected_shape}"
- st = ShapeTracker.from_shape(expected_shape)
+ axis = op.arg[-1] if op.op is ReduceOps.WMMA else op.arg
+ assert isinstance(axis, tuple) and all(isinstance(i, int) for i in axis), f"reduceop must have axis {op.arg}"
+ st = ShapeTracker.from_shape(reduce_st(sts[op.src[0]], axis))
  else:
  # movementops are pushed to the edges with LOAD
- if op.op in BufferOps: st = op.arg.st
- else: st = sts[op.src[0]]
- for x in op.src: assert sts[x].shape == st.shape, f"found implicit movement op {x.op} {sts[x].shape} != {op.op} {st.shape}"
+ # elementwise inherits shape
+ st = op.arg.st if op.op in BufferOps else sts[op.src[0]]
+ for x in op.src:
+ if sts[x].shape != st.shape:
+ if prod(sts[x].shape) == prod(st.shape): raise AssertionError(f"found implicit reshape {x.op} {op.op} {sts[x].shape} != {st.shape}")
+ raise AssertionError(f"found implicit expand {x.op} {sts[x].shape} != {op.op} {st.shape} {prod(sts[x].shape)} != {prod(st.shape)}")
  sts[op] = st
- for i, out in enumerate(ast):
+ for i, out in enumerate(ast.src):
  assert out.arg.idx == i, f"unexpected output buffer idx {out.arg.idx} != {i}"
  assert out.op is BufferOps.STORE, f"kernels must have stores as the output, got {out.op}"
- assert out.arg.st.size == ast[-1].arg.st.size, f"outputs must have the same size, got {out.arg.st.size}"
- dfs(out, out.arg.st)
+ assert out.arg.st.size == ast.src[-1].arg.st.size, f"outputs must have the same size, got {out.arg.st.size}"
+ assert_valid(out, out.arg.st)
+ shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])]
+ assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}"
+ return sts
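Two of the changes above can be checked from a REPL: `BinaryOps.IDIV` truncates toward zero (unlike Python's floor division), and `LazyOp` gains operator sugar plus a `const` constructor. A hedged sketch using only what the diff defines:

```python
from tinygrad.ops import LazyOp, BinaryOps, exec_alu
from tinygrad.dtype import dtypes

print(-7 // 2)                                          # -4: Python floors
print(exec_alu(BinaryOps.IDIV, dtypes.int32, (-7, 2)))  # -3: IDIV truncates toward zero

# the new operator overloads build small LazyOp ASTs directly
a = LazyOp.const(1.0, dtypes.float32, (4,))
b = LazyOp.const(2.0, dtypes.float32, (4,))
print((a + b).op, (a * b).op, a.ne(b).op)               # BinaryOps.ADD BinaryOps.MUL BinaryOps.CMPNE
```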
tinygrad/renderer/__init__.py CHANGED
@@ -1,8 +1,9 @@
- from typing import Optional, List, Tuple, Dict
+ from typing import Optional, List, Tuple, Dict, Callable, Any
  import functools
- from dataclasses import dataclass
- from tinygrad.helpers import getenv, to_function_name
- from tinygrad.codegen.uops import UOpGraph
+ from dataclasses import dataclass, field
+ from tinygrad.helpers import to_function_name, dedup
+ from tinygrad.codegen.uops import UOps, UOp, flops_mem
+ from tinygrad.ops import Op
  from tinygrad.shape.symbolic import sym_infer, sint, Variable
  from tinygrad.dtype import DType

@@ -12,30 +13,53 @@ class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x
  dtype_in: DType # dtype for A and B
  dtype_out: DType # dtype for C and D
  threads: List[Tuple[int,int]] # list of (TC dim,amt) that construct the warp thread structure
- thread_local_aliases: List[List[List[int]]] # a list of [threads_1, ..., threads_n, upcast_1(unrolled), upcast_2(upcast)] defining the alias (-1 is upcast, 1-n is warp threads) for each TC dim # noqa: E501
- thread_local_sizes: List[List[int]] # in each thread, the number of elements stored in registers for each TC dim
  def __str__(self): return "_".join(["WMMA"] + list(map(str, self.dims)) + [self.dtype_in.name, self.dtype_out.name])
- def num_upcasts(self): return len(self.thread_local_aliases[0]) - len(self.threads)

- @dataclass(frozen=True)
+ @dataclass
  class Program:
  name:str
  src:str
  dname:str
+ uops:Optional[List[UOp]]=None
+ mem_estimate:sint=0 # TODO: get this from the load/store uops once min/max are good
+
+ # filled in from uops (if we have uops)
  global_size:Optional[List[int]]=None
  local_size:Optional[List[int]]=None
- uops:Optional[UOpGraph]=None
- op_estimate:sint=0
- mem_estimate:sint=0
+ vars:List[Variable]=field(default_factory=list)
+ globals:List[int]=field(default_factory=list)
+ outs:List[int]=field(default_factory=list)
+ _ran_post_init:bool=False # NOTE: this is needed if you call replace on the Program

- @functools.cached_property
- def vars(self) -> List[Variable]: return [] if self.uops is None else self.uops.vars()
+ def __post_init__(self):
+ if not self._ran_post_init and self.uops is not None:
+ # single pass through the uops
+ for u in self.uops:
+ if u.op is UOps.DEFINE_VAR: self.vars.append(u.arg)
+ if u.op is UOps.DEFINE_GLOBAL: self.globals.append(u.arg)
+ if u.op is UOps.STORE: self.outs.extend([x.arg for x in u.src[0].sparents if x.op is UOps.DEFINE_GLOBAL])
+ if u.op is UOps.SPECIAL:
+ # NOTE: you have to set local_size and global_size to the base [1,1,1] outside this
+ if u.arg[0][0] == 'i': self.local_size = None
+ if u.arg[0][0] == 'l':
+ assert self.local_size is not None
+ self.local_size[int(u.arg[0][-1])] = u.arg[1]
+ else:
+ assert self.global_size is not None
+ self.global_size[int(u.arg[0][-1])] = u.arg[1]
+ self.vars = sorted(self.vars, key=lambda v: v.expr)
+ self.outs = sorted(dedup(self.outs))
+ self._ran_post_init = True

+ @property
+ def op_estimate(self) -> sint: return self._ops_lds[0]
+ @property
+ def lds_estimate(self) -> sint: return self._ops_lds[1]
  @functools.cached_property
- def globals(self) -> List[Tuple[int, bool]]: return [] if self.uops is None else self.uops.globals()
+ def _ops_lds(self) -> Tuple[sint, sint]: return (0,0) if self.uops is None else flops_mem(self.uops, ignore_indexing=True)

- @functools.cached_property
- def outcount(self) -> int: return sum(x[1] for x in self.globals)
+ @property
+ def outcount(self) -> int: return len(self.outs)

  @functools.cached_property
  def function_name(self) -> str: return to_function_name(self.name)
@@ -57,9 +81,7 @@ class Renderer:
  local_max: Optional[Tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: UOps.SPECIAL int32 indexes right now
  shared_max: int = 32768
  tensor_cores: List[TensorCore] = []
- @functools.cached_property
- def tc_opt(self): return getenv("TC_OPT")
- @functools.cached_property
- def tc(self): return getenv("TC", 1)
+ extra_matcher: Any = None
+ code_for_op: Dict[Op, Callable] = {}

- def render(self, name:str, uops:UOpGraph) -> str: raise NotImplementedError("needs a renderer")
+ def render(self, name:str, uops:List[UOp]) -> str: raise NotImplementedError("needs a renderer")
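Since `Program` is now a plain mutable dataclass whose `vars`/`globals`/`outs` and op/lds estimates are derived from its uop list, it can also be built without uops; a minimal hedged sketch (the kernel source string is just a placeholder):

```python
from tinygrad.renderer import Program

# with uops=None, __post_init__ does nothing and the flop/memory estimates fall back to 0
p = Program(name="copy", src="void copy() {}", dname="CLANG")
print(p.function_name)                             # copy
print(p.op_estimate, p.lds_estimate, p.outcount)   # 0 0 0
```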