tinygrad 0.7.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. tinygrad/__init__.py +6 -0
  2. tinygrad/codegen/kernel.py +572 -83
  3. tinygrad/codegen/linearizer.py +415 -395
  4. tinygrad/codegen/uops.py +415 -0
  5. tinygrad/device.py +183 -0
  6. tinygrad/dtype.py +113 -0
  7. tinygrad/engine/__init__.py +0 -0
  8. tinygrad/engine/graph.py +100 -0
  9. tinygrad/engine/jit.py +195 -0
  10. tinygrad/engine/realize.py +191 -0
  11. tinygrad/engine/schedule.py +362 -0
  12. tinygrad/engine/search.py +196 -0
  13. tinygrad/{mlops.py → function.py} +76 -55
  14. tinygrad/helpers.py +196 -89
  15. tinygrad/lazy.py +210 -371
  16. tinygrad/multi.py +169 -0
  17. tinygrad/nn/__init__.py +202 -22
  18. tinygrad/nn/datasets.py +7 -0
  19. tinygrad/nn/optim.py +112 -32
  20. tinygrad/nn/state.py +136 -39
  21. tinygrad/ops.py +119 -202
  22. tinygrad/renderer/__init__.py +61 -0
  23. tinygrad/renderer/assembly.py +276 -0
  24. tinygrad/renderer/cstyle.py +353 -166
  25. tinygrad/renderer/llvmir.py +150 -138
  26. tinygrad/runtime/autogen/amd_gpu.py +1900 -0
  27. tinygrad/runtime/autogen/comgr.py +865 -0
  28. tinygrad/runtime/autogen/cuda.py +5923 -0
  29. tinygrad/runtime/autogen/hip.py +5909 -0
  30. tinygrad/runtime/autogen/hsa.py +5761 -0
  31. tinygrad/runtime/autogen/kfd.py +812 -0
  32. tinygrad/runtime/autogen/nv_gpu.py +33328 -0
  33. tinygrad/runtime/autogen/opencl.py +1795 -0
  34. tinygrad/runtime/driver/hip_comgr.py +47 -0
  35. tinygrad/runtime/driver/hsa.py +143 -0
  36. tinygrad/runtime/graph/clang.py +38 -0
  37. tinygrad/runtime/graph/cuda.py +81 -0
  38. tinygrad/runtime/graph/hcq.py +143 -0
  39. tinygrad/runtime/graph/hsa.py +171 -0
  40. tinygrad/runtime/graph/metal.py +75 -0
  41. tinygrad/runtime/ops_amd.py +564 -0
  42. tinygrad/runtime/ops_clang.py +24 -77
  43. tinygrad/runtime/ops_cuda.py +175 -89
  44. tinygrad/runtime/ops_disk.py +56 -33
  45. tinygrad/runtime/ops_gpu.py +92 -95
  46. tinygrad/runtime/ops_hsa.py +278 -0
  47. tinygrad/runtime/ops_llvm.py +39 -60
  48. tinygrad/runtime/ops_metal.py +92 -74
  49. tinygrad/runtime/ops_npy.py +9 -0
  50. tinygrad/runtime/ops_nv.py +630 -0
  51. tinygrad/runtime/ops_python.py +204 -0
  52. tinygrad/shape/shapetracker.py +86 -254
  53. tinygrad/shape/symbolic.py +166 -141
  54. tinygrad/shape/view.py +296 -0
  55. tinygrad/tensor.py +2619 -448
  56. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
  57. tinygrad-0.9.0.dist-info/METADATA +227 -0
  58. tinygrad-0.9.0.dist-info/RECORD +60 -0
  59. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
  60. tinygrad/codegen/assembly.py +0 -190
  61. tinygrad/codegen/optimizer.py +0 -379
  62. tinygrad/codegen/search.py +0 -72
  63. tinygrad/graph.py +0 -83
  64. tinygrad/jit.py +0 -57
  65. tinygrad/nn/image.py +0 -100
  66. tinygrad/renderer/assembly_arm64.py +0 -169
  67. tinygrad/renderer/assembly_ptx.py +0 -98
  68. tinygrad/renderer/wgsl.py +0 -53
  69. tinygrad/runtime/lib.py +0 -113
  70. tinygrad/runtime/ops_cpu.py +0 -51
  71. tinygrad/runtime/ops_hip.py +0 -82
  72. tinygrad/runtime/ops_shm.py +0 -29
  73. tinygrad/runtime/ops_torch.py +0 -30
  74. tinygrad/runtime/ops_webgpu.py +0 -45
  75. tinygrad-0.7.0.dist-info/METADATA +0 -212
  76. tinygrad-0.7.0.dist-info/RECORD +0 -40
  77. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
tinygrad/multi.py ADDED
@@ -0,0 +1,169 @@
+ from __future__ import annotations
+ from typing import Optional, Union, Any, Tuple, List
+ import functools, itertools, operator
+ from tinygrad.helpers import all_same, all_int, dedup, round_up, prod, DEBUG, RING
+ from tinygrad.dtype import DType, ConstType
+ from tinygrad.ops import BinaryOps, LoadOps, UnaryOps, TernaryOps, ReduceOps
+ from tinygrad.lazy import LazyBuffer
+ from tinygrad.shape.shapetracker import sint
+
+ def all_reduce(op: ReduceOps, lbs: List[LazyBuffer]) -> List[LazyBuffer]:
+   assert all_int(lbs[0].shape), f"does not support symbolic shape {lbs[0].shape}"
+   assert all_same([lb.shape[0] for lb in lbs]), "allreduce with uneven shards is undefined"
+   bop = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[op]
+
+   n_lbs, dim = len(lbs), prod(lbs[0].shape)
+   # Ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
+   # so just fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
+   use_ring = (RING >= 2 or (n_lbs > 2 and dim > 256_000 and RING >= 1))
+   if DEBUG >= 2: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{dim} | {lbs[0].dtype}")
+   if not use_ring:
+     return [functools.reduce(lambda x,y: x.e(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
+   factor = max(f for f in [32, 16, 8, 4, 2, 1] if dim % f == 0)
+   base, left = (dim // factor) // n_lbs, (dim // factor) % n_lbs
+   c_lens = [(base + 1) * factor if i < left else base * factor for i in range(n_lbs)]
+   acc = 0
+   chunks = [(acc, (acc := acc + i)) for i in c_lens if i > 0]
+   chunked = [[lb.reshape((dim,)).shrink(((s,e),)) for s,e in chunks] for lb in lbs]
+
+   # Scatter-reduce step
+   for step in range(n_lbs - 1):
+     for i in range(len(chunks)):
+       s, r = (i+step)%n_lbs, (i+step+1)%n_lbs
+       chunked[r][i] = chunked[r][i].e(bop, chunked[s][i].copy_to_device(chunked[r][i].device, force=True))
+
+   # Allgather step
+   for step in range(n_lbs - 1):
+     for i in range(len(chunks)):
+       s, r = (i+step-1)%n_lbs, (i+step)%n_lbs
+       chunked[r][i] = chunked[s][i].copy_to_device(chunked[r][i].device, force=True)
+
+   # Assemble chunks back
+   pads = [((s,dim-e),) for s,e in chunks]
+   return [functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), [c.pad(pads[i]) for i,c in enumerate(lb_c)]).reshape(lbs[0].shape) for lb_c in chunked]
+
+ def to_sharded(lbs:List[LazyBuffer], axis:int) -> List[LazyBuffer]:
+   if DEBUG >= 3 and lbs[0].shape[axis] % len(lbs) != 0: print(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}")
+   sz = round_up(lbs[0].shape[axis], len(lbs)) // len(lbs)
+   return [lb.shrink(tuple((0,s) if a != axis else (sz*i,min(s,sz*(i+1))) for a,s in enumerate(lb.shape))) for i,lb in enumerate(lbs)]
+
+ class MultiLazyBuffer:
+   def __init__(self, lbs:List[LazyBuffer], axis:Optional[int], real:Optional[List[bool]]=None):
+     assert all(isinstance(x, LazyBuffer) for x in lbs) and len(lbs), "all lbs must be LazyBuffers, and we need at least one of them"
+     assert all_same([x.dtype for x in lbs]), f"all multilazybuffer needs same dtype, getting {[x.dtype for x in lbs]}"
+     self.lbs, self.axis, self.dtype, self.device, self.real = lbs, axis, lbs[0].dtype, tuple(x.device for x in lbs), real or [True]*len(lbs)
+     if axis is not None:
+       splits = list(itertools.accumulate([lb.shape[axis] for lb in lbs], initial=0))
+       self.bounds = [(st,ed) for st,ed in zip(splits, splits[1:])]
+
+   @property
+   def shape(self):
+     return tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape))
+
+   @property
+   def size(self): return sum(x.size for x in self.real_lbs)
+
+   @property
+   def real_lbs(self): return [lb for lb,r in zip(self.lbs, self.real) if r]
+
+   def __repr__(self):
+     return f"<MLB {self.axis=} {self.real=} {chr(10)}{chr(10).join([f'{x.device} {x.st}' for x in self.lbs])}>"
+
+   @staticmethod
+   def from_sharded(lb:LazyBuffer, devices:Tuple[str, ...], axis:Optional[int]=None):
+     lbs = [lb.contiguous() if lb.base != lb and not lb.is_unrealized_unmasked_const() else lb] * len(devices)
+     sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(to_sharded(lbs, axis) if axis is not None else lbs, devices)]
+     return MultiLazyBuffer([lb if lb.is_unrealized_unmasked_const() else lb.contiguous() for lb in sharded_lbs], axis)
+
+   def copy_to_device(self, device:str) -> LazyBuffer:
+     if self.axis is None: return self.lbs[self.real.index(True)].copy_to_device(device)
+     sz = self.lbs[0].shape[self.axis]
+     llbs = []
+     for i,lb in enumerate([lb.copy_to_device(device) for lb in self.real_lbs]):
+       pad_arg = tuple((0,0) if a != self.axis else (sz*i, max(0, self.shape[self.axis]-sz*(i+1))) for a in range(len(lb.shape)))
+       llbs.append(lb.pad(pad_arg))
+     return functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), llbs)
+
+   # passthroughs
+   def is_realized(self) -> bool: return all(lb.base.realized is not None for lb, r in zip(self.lbs, self.real) if r is True)
+   def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis, self.real)
+   def const(self, val:ConstType) -> MultiLazyBuffer: return MultiLazyBuffer([x.const(val) for x in self.lbs], self.axis, self.real)
+   def assign(self, x:MultiLazyBuffer): return MultiLazyBuffer([s.assign(d) for s,d in zip(self.lbs, x.lbs)], self.axis, self.real)
+   def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis, self.real)
+
+   # elementwise is simple
+   def e(self, op:Union[LoadOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:MultiLazyBuffer, arg:Optional[Any]=None) -> MultiLazyBuffer:
+     msrcs = (self,)+in_srcs
+     assert all(isinstance(x, MultiLazyBuffer) for x in msrcs), f"all buffers must be MultiLazyBuffer {msrcs}"
+     assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"
+
+     # NOTE: they all have to share an axis, we always choose [-1]
+     axis = axes[-1] if len(axes := dedup([x.axis for x in msrcs if x.axis is not None])) else None
+     srcs = []
+     not_all_real = any(not all(mlb.real) for mlb in msrcs)
+     new_real = [all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])] if not_all_real else self.real
+     assert any(new_real), "output contains no real lb"
+     for mlb in msrcs:
+       if mlb.axis == axis or not_all_real: srcs.append(mlb.lbs)
+       elif mlb.axis is None and axis is not None: srcs.append(to_sharded(mlb.lbs, axis))
+       else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis))
+     # NOTE: lsrcs[-1].const(0) is correct for where
+     return MultiLazyBuffer([lsrcs[0].e(op, *lsrcs[1:], arg=arg) if r else lsrcs[-1].const(0) for lsrcs,r in zip(zip(*srcs),new_real)], axis, new_real)
+
+   def _shape_to_single_shard(self, shape:Tuple[sint, ...], lb:LazyBuffer) -> Tuple[sint, ...]:
+     return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape))
+
+   def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> MultiLazyBuffer:
+     if self.axis is not None and self.axis in axis:
+       # all-reduce on sharded axes
+       reduced_parts = [(x if r else x.const(0)).r(op, axis) for x,r in zip(self.lbs, self.real)]
+       if all(self.real): return MultiLazyBuffer(all_reduce(op, reduced_parts), None)
+       return MultiLazyBuffer(reduced_parts, None, self.real)
+     # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
+     return MultiLazyBuffer([x.r(op, axis) for x in self.lbs], self.axis, self.real)
+
+   # *** movement ops ***
+
+   def reshape(self, arg:Tuple[sint, ...]):
+     if self.axis is None: return MultiLazyBuffer([x.reshape(arg) for x in self.lbs], None, self.real)
+     arg_acc:List[sint] = list(itertools.accumulate(arg, operator.mul, initial=1))
+     # new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
+     # todo: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
+     new_axis = len(arg_acc) - arg_acc[::-1].index(prod(self.shape[:self.axis])) - 1
+     if arg[new_axis] != self.shape[self.axis]:
+       assert self.shape[self.axis] % len(self.real_lbs) == 0, f"cannot reshape on-axis for uneven shard {self.axis} {self.shape} {len(self.real_lbs)}"
+       assert arg[new_axis] % len(self.real_lbs) == 0, f"new on-axis shape must divide evenly between devices {new_axis} {arg} {len(self.real_lbs)}"
+     return MultiLazyBuffer([x.reshape(tuple(s if a != new_axis else
+                                             x.shape[self.axis] if s == self.shape[self.axis] else
+                                             s // len(self.real_lbs) for a,s in enumerate(arg))) for x in self.lbs],
+                            new_axis, self.real)
+
+   def pad(self, arg:Tuple[Tuple[sint, sint], ...]):
+     assert self.axis is None or arg[self.axis] == (0,0) or not all(self.real), f"padding not supported for {arg=}"
+     # pad on shard axis -> fill others with zeros and set real to all True
+     if self.axis is not None and arg[self.axis] != (0,0):
+       # pad back to whole axis, remove real mask
+       assert all(arg[i] == (0, 0) or i == self.axis for i in range(len(self.shape))), "cannot pad sharded and non-sharded axis at the same time"
+       assert arg[self.axis] == (sum(lb.shape[self.axis] for i,lb in enumerate(self.lbs) if i < self.real.index(True)), \
+                                 sum(lb.shape[self.axis] for i,lb in enumerate(self.lbs) if i > self.real.index(True))), "can only pad to whole axis"
+       return MultiLazyBuffer([x if r else x.const(0) for x,r in zip(self.lbs, self.real)], self.axis)
+     return MultiLazyBuffer([x.pad(arg) for x in self.lbs], self.axis, self.real)
+   def expand(self, arg:Tuple[sint, ...]):
+     # NOTE: this assert isn't needed, sharded axis can have dim 1
+     assert self.axis is None or arg[self.axis] == self.shape[self.axis], f"expand not supported on sharded axis {arg=}"
+     return MultiLazyBuffer([x.expand(self._shape_to_single_shard(arg, x)) for x in self.lbs], self.axis, self.real)
+   def permute(self, arg:Tuple[int, ...]):
+     # all permutes supported!
+     return MultiLazyBuffer([x.permute(arg) for x in self.lbs], arg.index(self.axis) if self.axis is not None else None, self.real)
+   def shrink(self, arg:Tuple[Tuple[sint, sint], ...]):
+     assert self.axis is None or arg[self.axis] == (0, self.shape[self.axis]) or arg[self.axis] in self.bounds, f"shrinking not supported for {arg=}"
+     if self.axis is not None and arg[self.axis] in self.bounds and arg[self.axis] != (0, self.shape[self.axis]):
+       assert all(arg[i] == (0, s) or i == self.axis for i,s in enumerate(self.shape)), "cannot shrink sharded and non-sharded axis at the same time"
+       idx = self.bounds.index(arg[self.axis])
+       # zero out other lbs to not create lb reference
+       return MultiLazyBuffer([lb if i==idx else lb.const(0) for i,lb in enumerate(self.lbs)], self.axis, [i==idx for i in range(len(self.lbs))])
+     return MultiLazyBuffer([x.shrink(tuple((0, x.shape[self.axis]) if a == self.axis else s for a,s in enumerate(arg))) for x in self.lbs],
+                            self.axis, self.real)
+   def stride(self, arg:Tuple[int, ...]):
+     assert self.axis is None or arg[self.axis] == 1, "flipping not supported on sharded axis"
+     return MultiLazyBuffer([x.stride(arg) for x in self.lbs], self.axis, self.real)
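In practice this new file is not used directly; sharding is driven from the `Tensor` API via `Tensor.shard` in 0.9.0. A minimal sketch of that path, assuming two devices are available (the `CLANG:0`/`CLANG:1` strings below are placeholders for whatever backend the machine actually has, not part of this diff):

```python
from tinygrad import Tensor

# shard a (4, 64) tensor across two devices along axis 0; each shard becomes one
# LazyBuffer inside a MultiLazyBuffer (device names here are illustrative placeholders)
t = Tensor.arange(256).reshape(4, 64).shard(("CLANG:0", "CLANG:1"), axis=0)

# reducing over the sharded axis routes through all_reduce() above; with only two
# shards it takes the naive path rather than the ring allreduce
print(t.sum(axis=0).numpy())
```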
tinygrad/nn/__init__.py CHANGED
@@ -1,10 +1,35 @@
  import math
- from typing import Optional, Union, Tuple
+ from typing import Optional, Union, Tuple, cast
  from tinygrad.tensor import Tensor
  from tinygrad.helpers import prod
+ from tinygrad.nn import optim, state # noqa: F401

  class BatchNorm2d:
-   def __init__(self, sz, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
+   """
+   Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs with additional channel dimension).
+
+   - Described: https://paperswithcode.com/method/batch-normalization
+   - Paper: https://arxiv.org/abs/1502.03167v3
+
+   See: `Tensor.batchnorm`
+
+   ```python exec="true" session="tensor"
+   from tinygrad import Tensor, dtypes, nn
+   import numpy as np
+   np.set_printoptions(precision=4)
+   ```
+
+   ```python exec="true" source="above" session="tensor" result="python"
+   norm = nn.BatchNorm2d(3)
+   t = Tensor.rand(2, 3, 4, 4)
+   print(t.mean().item(), t.std().item())
+   ```
+   ```python exec="true" source="above" session="tensor" result="python"
+   t = norm(t)
+   print(t.mean().item(), t.std().item())
+   ```
+   """
+   def __init__(self, sz:int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
      self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum

      if affine: self.weight, self.bias = Tensor.ones(sz), Tensor.zeros(sz)
@@ -19,14 +44,14 @@ class BatchNorm2d:
        # https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
        # There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
        batch_mean = x.mean(axis=(0,2,3))
-       y = (x - batch_mean.reshape(shape=[1, -1, 1, 1]))
+       y = (x - batch_mean.detach().reshape(shape=[1, -1, 1, 1])) # d(var)/d(mean) = 0
        batch_var = (y*y).mean(axis=(0,2,3))
        batch_invstd = batch_var.add(self.eps).pow(-0.5)

        # NOTE: wow, this is done all throughout training in most PyTorch models
        if self.track_running_stats:
-         self.running_mean.assign((1 - self.momentum) * self.running_mean + self.momentum * batch_mean.detach())
-         self.running_var.assign((1 - self.momentum) * self.running_var + self.momentum * prod(y.shape)/(prod(y.shape) - y.shape[1]) * batch_var.detach() )
+         self.running_mean.assign((1-self.momentum) * self.running_mean + self.momentum * batch_mean.detach())
+         self.running_var.assign((1-self.momentum) * self.running_var + self.momentum * prod(y.shape)/(prod(y.shape)-y.shape[1]) * batch_var.detach())
          self.num_batches_tracked += 1
      else:
        batch_mean = self.running_mean
@@ -37,43 +62,139 @@ class BatchNorm2d:

  # TODO: these Conv lines are terrible
  def Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
+   """
+   Applies a 1D convolution over an input signal composed of several input planes.
+
+   See: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d
+
+   ```python exec="true" source="above" session="tensor" result="python"
+   conv = nn.Conv1d(1, 1, 3)
+   t = Tensor.rand(1, 1, 4)
+   print(t.numpy())
+   ```
+   ```python exec="true" source="above" session="tensor" result="python"
+   t = conv(t)
+   print(t.numpy())
+   ```
+   """
    return Conv2d(in_channels, out_channels, (kernel_size,), stride, padding, dilation, groups, bias)

  class Conv2d:
+   """
+   Applies a 2D convolution over an input signal composed of several input planes.
+
+   See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d
+
+   ```python exec="true" source="above" session="tensor" result="python"
+   conv = nn.Conv2d(1, 1, 3)
+   t = Tensor.rand(1, 1, 4, 4)
+   print(t.numpy())
+   ```
+   ```python exec="true" source="above" session="tensor" result="python"
+   t = conv(t)
+   print(t.numpy())
+   ```
+   """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
      self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
      self.stride, self.padding, self.dilation, self.groups = stride, padding, dilation, groups
-     self.weight = Tensor.kaiming_uniform(out_channels, in_channels//groups, *self.kernel_size, a=math.sqrt(5))
-     bound = 1 / math.sqrt(prod(self.weight.shape[1:]))
+     self.weight = self.initialize_weight(out_channels, in_channels, groups)
+     bound = 1 / math.sqrt(cast(int, prod(self.weight.shape[1:]))) # weight shape is always ints but mypy cannot tell
      self.bias = Tensor.uniform(out_channels, low=-bound, high=bound) if bias else None

-   def __call__(self, x):
+   def __call__(self, x:Tensor):
      return x.conv2d(self.weight, self.bias, padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)

+   def initialize_weight(self, out_channels, in_channels, groups):
+     return Tensor.kaiming_uniform(out_channels, in_channels//groups, *self.kernel_size, a=math.sqrt(5))
+
  def ConvTranspose1d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
+   """
+   Applies a 1D transposed convolution operator over an input signal composed of several input planes.
+
+   See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d
+
+   ```python exec="true" source="above" session="tensor" result="python"
+   conv = nn.ConvTranspose1d(1, 1, 3)
+   t = Tensor.rand(1, 1, 4)
+   print(t.numpy())
+   ```
+   ```python exec="true" source="above" session="tensor" result="python"
+   t = conv(t)
+   print(t.numpy())
+   ```
+   """
    return ConvTranspose2d(in_channels, out_channels, (kernel_size,), stride, padding, output_padding, dilation, groups, bias)

- class ConvTranspose2d:
+ class ConvTranspose2d(Conv2d):
+   """
+   Applies a 2D transposed convolution operator over an input image.
+
+   See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d
+
+   ```python exec="true" source="above" session="tensor" result="python"
+   conv = nn.ConvTranspose2d(1, 1, 3)
+   t = Tensor.rand(1, 1, 4, 4)
+   print(t.numpy())
+   ```
+   ```python exec="true" source="above" session="tensor" result="python"
+   t = conv(t)
+   print(t.numpy())
+   ```
+   """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
-     self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
-     self.stride, self.padding, self.output_padding, self.dilation, self.groups = stride, padding, output_padding, dilation, groups
-     self.weight = Tensor.kaiming_uniform(in_channels, out_channels//groups, *self.kernel_size, a=math.sqrt(5))
-     bound = 1 / math.sqrt(prod(self.weight.shape[1:]))
-     self.bias = Tensor.uniform(out_channels, low=-bound, high=bound) if bias else None
+     super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
+     self.output_padding = output_padding

-   def __call__(self, x):
-     return x.conv_transpose2d(self.weight, self.bias, padding=self.padding, output_padding=self.output_padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
+   def __call__(self, x:Tensor):
+     return x.conv_transpose2d(self.weight, self.bias, padding=self.padding, output_padding=self.output_padding, stride=self.stride,
+                               dilation=self.dilation, groups=self.groups)
+
+   def initialize_weight(self, out_channels, in_channels, groups):
+     return Tensor.kaiming_uniform(in_channels, out_channels//groups, *self.kernel_size, a=math.sqrt(5))

  class Linear:
+   """
+   Applies a linear transformation to the incoming data.
+
+   See: https://pytorch.org/docs/stable/generated/torch.nn.Linear
+
+   ```python exec="true" source="above" session="tensor" result="python"
+   lin = nn.Linear(3, 4)
+   t = Tensor.rand(2, 3)
+   print(t.numpy())
+   ```
+   ```python exec="true" source="above" session="tensor" result="python"
+   t = lin(t)
+   print(t.numpy())
+   ```
+   """
    def __init__(self, in_features, out_features, bias=True):
+     # TODO: is this init good? torch inits to uniform(-1/sqrt(in_features), 1/sqrt(in_features))
      self.weight = Tensor.kaiming_uniform(out_features, in_features, a=math.sqrt(5))
-     bound = 1 / math.sqrt(self.weight.shape[1])
+     bound = 1 / math.sqrt(in_features)
      self.bias = Tensor.uniform(out_features, low=-bound, high=bound) if bias else None

-   def __call__(self, x):
+   def __call__(self, x:Tensor):
      return x.linear(self.weight.transpose(), self.bias)

  class GroupNorm:
+   """
+   Applies Group Normalization over a mini-batch of inputs.
+
+   - Described: https://paperswithcode.com/method/group-normalization
+   - Paper: https://arxiv.org/abs/1803.08494v3
+
+   ```python exec="true" source="above" session="tensor" result="python"
+   norm = nn.GroupNorm(2, 12)
+   t = Tensor.rand(2, 12, 4, 4) * 2 + 1
+   print(t.mean().item(), t.std().item())
+   ```
+   ```python exec="true" source="above" session="tensor" result="python"
+   t = norm(t)
+   print(t.mean().item(), t.std().item())
+   ```
+   """
    def __init__(self, num_groups:int, num_channels:int, eps:float=1e-5, affine:bool=True):
      self.num_groups, self.num_channels, self.eps = num_groups, num_channels, eps
      self.weight: Optional[Tensor] = Tensor.ones(num_channels) if affine else None
@@ -89,6 +210,22 @@ class GroupNorm:
      return x * self.weight.reshape(1, -1, *[1] * (len(x.shape)-2)) + self.bias.reshape(1, -1, *[1] * (len(x.shape)-2))

  class InstanceNorm:
+   """
+   Applies Instance Normalization over a mini-batch of inputs.
+
+   - Described: https://paperswithcode.com/method/instance-normalization
+   - Paper: https://arxiv.org/abs/1607.08022v3
+
+   ```python exec="true" source="above" session="tensor" result="python"
+   norm = nn.InstanceNorm(3)
+   t = Tensor.rand(2, 3, 4, 4) * 2 + 1
+   print(t.mean().item(), t.std().item())
+   ```
+   ```python exec="true" source="above" session="tensor" result="python"
+   t = norm(t)
+   print(t.mean().item(), t.std().item())
+   ```
+   """
    def __init__(self, num_features:int, eps:float=1e-5, affine:bool=True):
      self.num_features, self.eps = num_features, eps
      self.weight: Optional[Tensor] = Tensor.ones(num_features) if affine else None
@@ -100,6 +237,22 @@ class InstanceNorm:
      return x * self.weight.reshape(1, -1, *[1] * (len(x.shape)-2)) + self.bias.reshape(1, -1, *[1] * (len(x.shape)-2))

  class LayerNorm:
+   """
+   Applies Layer Normalization over a mini-batch of inputs.
+
+   - Described: https://paperswithcode.com/method/layer-normalization
+   - Paper: https://arxiv.org/abs/1607.06450v1
+
+   ```python exec="true" source="above" session="tensor" result="python"
+   norm = nn.LayerNorm(3)
+   t = Tensor.rand(2, 5, 3) * 2 + 1
+   print(t.mean().item(), t.std().item())
+   ```
+   ```python exec="true" source="above" session="tensor" result="python"
+   t = norm(t)
+   print(t.mean().item(), t.std().item())
+   ```
+   """
    def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps:float=1e-5, elementwise_affine:bool=True):
      self.normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
      self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(self.normalized_shape))), eps, elementwise_affine
@@ -112,13 +265,40 @@ class LayerNorm:
      return x * self.weight + self.bias

  class LayerNorm2d(LayerNorm):
+   """
+   Applies Layer Normalization over a mini-batch of 2D inputs.
+
+   See: `LayerNorm`
+
+   ```python exec="true" source="above" session="tensor" result="python"
+   norm = nn.LayerNorm2d(3)
+   t = Tensor.rand(2, 3, 4, 4) * 2 + 1
+   print(t.mean().item(), t.std().item())
+   ```
+   ```python exec="true" source="above" session="tensor" result="python"
+   t = norm(t)
+   print(t.mean().item(), t.std().item())
+   ```
+   """
    def __call__(self, x): return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

  class Embedding:
+   """
+   A simple lookup table that stores embeddings of a fixed dictionary and size.
+
+   See: https://pytorch.org/docs/stable/generated/torch.nn.Embedding
+
+   ```python exec="true" source="above" session="tensor" result="python"
+   emb = nn.Embedding(10, 3)
+   print(emb(Tensor([1, 2, 3, 1])).numpy())
+   ```
+   """
    def __init__(self, vocab_size:int, embed_size:int):
-     self.vocab_size = vocab_size
-     self.weight = Tensor.glorot_uniform(vocab_size, embed_size)
+     self.vocab_sz, self.embed_sz, self.weight = vocab_size, embed_size, Tensor.glorot_uniform(vocab_size, embed_size)

    def __call__(self, idx:Tensor) -> Tensor:
-     if not hasattr(self, 'vocab_counter'): self.vocab_counter = Tensor.arange(self.vocab_size, requires_grad=False).reshape(1, 1, self.vocab_size)
-     return (self.vocab_counter == idx.unsqueeze(2)).expand(*idx.shape, self.vocab_size) @ self.weight
+     if idx.numel() == 0: return Tensor.empty(idx.shape+(self.embed_sz,), device=self.weight.device)
+     arange_shp, weight_shp, big_shp = (1, 1, self.vocab_sz, 1), (1, 1, self.vocab_sz, self.embed_sz), idx.shape+(self.vocab_sz, self.embed_sz,)
+     if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).reshape(arange_shp)
+     arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1,)).expand(big_shp), self.weight.reshape(weight_shp).expand(big_shp)
+     return (arange == idx).mul(vals).sum(2)
tinygrad/nn/datasets.py ADDED
@@ -0,0 +1,7 @@
+ import gzip
+ from tinygrad import Tensor, fetch
+
+ def _fetch_mnist(file, offset): return Tensor(gzip.open(fetch("https://storage.googleapis.com/cvdf-datasets/mnist/"+file)).read()[offset:])
+ def mnist():
+   return _fetch_mnist("train-images-idx3-ubyte.gz", 0x10).reshape(-1, 1, 28, 28), _fetch_mnist("train-labels-idx1-ubyte.gz", 8), \
+          _fetch_mnist("t10k-images-idx3-ubyte.gz", 0x10).reshape(-1, 1, 28, 28), _fetch_mnist("t10k-labels-idx1-ubyte.gz", 8)
tinygrad/nn/optim.py CHANGED
@@ -1,64 +1,144 @@
  # sorted in order of increasing complexity
  from typing import List
- from tinygrad.helpers import dedup
+ from tinygrad.helpers import dedup, flatten, getenv
  from tinygrad.tensor import Tensor
+ from tinygrad.dtype import dtypes, least_upper_dtype

  class Optimizer:
+   """
+   Base class for all optimizers.
+   """
    def __init__(self, params: List[Tensor], lr: float):
      # if it's None, but being put into an optimizer, set it to True
      for x in params:
        if x.requires_grad is None: x.requires_grad = True

      self.params: List[Tensor] = dedup([x for x in params if x.requires_grad])
+     assert len(self.params) != 0, "optimizer must have at least one param"
+     self.device = self.params[0].device
      self.buffers: List[Tensor] = dedup([x for x in params if not x.requires_grad]) # buffers are still realized
-     self.lr = Tensor([lr], requires_grad=False).contiguous()
+     # store lr in at least float32 precision
+     self.lr = Tensor(lr if getenv("CONST_LR") else [lr], requires_grad=False, device=self.device,
+                      dtype=least_upper_dtype(dtypes.default_float, dtypes.float32))

    def zero_grad(self):
+     """
+     Zeroes the gradients of all the parameters.
+     """
      for param in self.params: param.grad = None

-   def realize(self, extra=None):
-     # TODO: corealize
-     # NOTE: in extra is too late for most of the params due to issues with assign
-     for p in extra + self.params + self.buffers if extra is not None else self.params + self.buffers:
-       p.realize()
+   def step(self):
+     """
+     Performs a single optimization step.
+     """
+     Tensor.realize(*self.schedule_step())
+   def schedule_step(self) -> List[Tensor]:
+     """
+     Returns the tensors that need to be realized to perform a single optimization step.
+     """
+     assert Tensor.training, (
+             f"""Tensor.training={Tensor.training}, Tensor.training must be enabled to use the optimizer.
+             - help: Consider setting Tensor.training=True before calling Optimizer.step().""")
+     return self._step()+self.params+self.buffers
+   def _step(self) -> List[Tensor]: raise NotImplementedError

- class SGD(Optimizer):
-   def __init__(self, params: List[Tensor], lr=0.001, momentum=0, weight_decay=0.0, nesterov=False):
+ class OptimizerGroup(Optimizer):
+   """
+   Combines multiple optimizers into one.
+   """
+   def __init__(self, *optimizers: Optimizer): # pylint: disable=super-init-not-called
+     self.optimizers = optimizers
+     self.params, self.buffers = flatten([o.params for o in self.optimizers]), flatten([o.buffers for o in self.optimizers])
+   def __getitem__(self, i): return self.optimizers[i]
+   def zero_grad(self): [o.zero_grad() for o in self.optimizers]
+   def _step(self) -> List[Tensor]: return [x for o in self.optimizers for x in o._step()]
+
+ # LARS is essentially just trust ratio to SGD so if we just set the trust coeff 0.0 its just standard SGD.
+ def SGD(params: List[Tensor], lr=0.001, momentum=0.0, weight_decay=0.0, nesterov=False, classic=False):
+   """
+   Stochastic Gradient Descent (SGD) optimizer with optional momentum and weight decay.
+
+   `classic` is a boolean flag that determines whether to use the popular momentum update rule or the classic momentum update rule.
+
+   - Described: https://paperswithcode.com/method/sgd
+   """
+   return LARS(params, lr, momentum, weight_decay, nesterov, classic, tcoef=0.0)
+
+ class LARS(Optimizer):
+   """
+   Layer-wise Adaptive Rate Scaling (LARS) optimizer with optional momentum and weight decay.
+
+   - Described: https://paperswithcode.com/method/lars
+   - Paper: https://arxiv.org/abs/1708.03888v3
+   """
+   def __init__(self, params:List[Tensor], lr=0.001, momentum=0.9, weight_decay=1e-4, nesterov=False, classic=True, tcoef=0.001):
      super().__init__(params, lr)
-     self.momentum, self.wd, self.nesterov = momentum, weight_decay, nesterov
-     self.b = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params] if self.momentum else []
+     self.momentum, self.wd, self.nesterov, self.classic, self.tcoef = momentum, weight_decay, nesterov, classic, tcoef
+     self.b = [Tensor.zeros(*t.shape, dtype=t.dtype, device=t.device, requires_grad=False) for t in self.params] if self.momentum else []

-   # https://pytorch.org/docs/stable/generated/torch.optim.SGD.html
-   def step(self) -> None:
+   def _step(self) -> List[Tensor]:
      for i, t in enumerate(self.params):
        assert t.grad is not None
-       g = t.grad.realize() + self.wd * t.detach()
+       # contiguous is needed since the grads can allegedly form a "diamond"
+       # TODO: fix this in lazy.py
+       g = t.grad.contiguous()
+       if self.tcoef != 0:
+         r1 = t.detach().square().sum().sqrt()
+         r2 = g.square().sum().sqrt()
+         r = (r1 > 0).where((r2 > 0).where(self.tcoef * r1 / (r2 + self.wd * r1), 1.0), 1.0)
+       else: r = 1.0
+       g = g + self.wd * t.detach()
+       # classic momentum does post learning rate update
+       if self.classic: g = g * r * self.lr
        if self.momentum:
-         self.b[i].assign(self.momentum * self.b[i] + g).realize() # NOTE: self.b[i] is zero on the first run, no if required
+         self.b[i].assign(self.momentum * self.b[i] + g) # NOTE: self.b[i] is zero on the first run, no if required
          g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
-       t.assign(t.detach() - g * self.lr)
-     self.realize(self.b)
+       # popular momentum does pre learning rate update
+       if not self.classic: g = g * r * self.lr
+       t.assign((t.detach() - g).cast(t.dtype))
+     return self.b

  # LAMB is essentially just the trust ratio part of LARS applied to Adam/W so if we just set the trust ratio to 1.0 its just Adam/W.
- def AdamW(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, wd=0.01): return LAMB(params, lr, b1, b2, eps, wd, adam=True)
- def Adam(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8): return LAMB(params, lr, b1, b2, eps, 0.0, adam=True)
+ def AdamW(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.01):
+   """
+   AdamW optimizer with optional weight decay.
+
+   - Described: https://paperswithcode.com/method/adamw
+   - Paper: https://arxiv.org/abs/1711.05101v3
+   """
+   return LAMB(params, lr, b1, b2, eps, weight_decay, adam=True)
+ def Adam(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
+   """
+   Adam optimizer.
+
+   - Described: https://paperswithcode.com/method/adam
+   - Paper: https://arxiv.org/abs/1412.6980
+   """
+   return LAMB(params, lr, b1, b2, eps, 0.0, adam=True)

  class LAMB(Optimizer):
-   def __init__(self, params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, wd=0.0, adam=False):
+   """
+   LAMB optimizer with optional weight decay.
+
+   - Described: https://paperswithcode.com/method/lamb
+   - Paper: https://arxiv.org/abs/1904.00962
+   """
+   def __init__(self, params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, adam=False):
      super().__init__(params, lr)
-     self.b1, self.b2, self.eps, self.wd, self.adam, self.t = b1, b2, eps, wd, adam, Tensor([0], requires_grad=False).realize()
-     self.m = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]
-     self.v = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]
+     self.b1, self.b2, self.eps, self.wd, self.adam = b1, b2, eps, weight_decay, adam
+     self.b1_t, self.b2_t = (Tensor([1], dtype=dtypes.float32, device=self.device, requires_grad=False).realize() for _ in [b1, b2])
+     self.m = [Tensor.zeros(*t.shape, dtype=dtypes.float32, device=t.device, requires_grad=False).contiguous() for t in self.params]
+     self.v = [Tensor.zeros(*t.shape, dtype=dtypes.float32, device=t.device, requires_grad=False).contiguous() for t in self.params]

-   def step(self) -> None:
-     self.t.assign(self.t + 1).realize()
+   def _step(self) -> List[Tensor]:
+     self.b1_t *= self.b1
+     self.b2_t *= self.b2
      for i, t in enumerate(self.params):
        assert t.grad is not None
-       g = t.grad.realize()
-       self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * g).realize()
-       self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (g * g)).realize()
-       m_hat = self.m[i] / (1.0 - self.b1**self.t)
-       v_hat = self.v[i] / (1.0 - self.b2**self.t)
+       self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * t.grad)
+       self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad))
+       m_hat = self.m[i] / (1.0 - self.b1_t)
+       v_hat = self.v[i] / (1.0 - self.b2_t)
        up = (m_hat / (v_hat.sqrt() + self.eps)) + self.wd * t.detach()
        if not self.adam:
          r1 = t.detach().square().sum().sqrt()
@@ -66,5 +146,5 @@ class LAMB(Optimizer):
          r = Tensor.where(r1 > 0, Tensor.where(r2 > 0, r1 / r2, 1.0), 1.0)
        else:
          r = 1.0
-       t.assign(t.detach() - self.lr * r * up)
-     self.realize([self.t] + self.m + self.v)
+       t.assign((t.detach() - self.lr * r * up).cast(t.dtype))
+     return [self.b1_t, self.b2_t] + self.m + self.v
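The optimizer rework above also changes the calling convention: `step()` now asserts `Tensor.training` and schedules the whole update itself, replacing the explicit `.realize()` calls from 0.7.0. A rough sketch of the new contract, where the tiny `Linear` model and random data are illustrative only:

```python
from tinygrad import Tensor
from tinygrad.nn import Linear
from tinygrad.nn.optim import SGD

model = Linear(4, 2)
opt = SGD([model.weight, model.bias], lr=0.01, momentum=0.9)

with Tensor.train():  # Tensor.training must be True or step() will assert
  x, y = Tensor.rand(8, 4), Tensor.rand(8, 2)
  loss = (model(x) - y).square().mean()
  opt.zero_grad()
  loss.backward()
  opt.step()  # realizes the params, buffers and momentum tensors in one go
```

Collecting parameters with `nn.state.get_parameters(model)` is the more common pattern; listing them explicitly just keeps the sketch self-contained.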