tinygrad 0.7.0-py3-none-any.whl → 0.9.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +6 -0
- tinygrad/codegen/kernel.py +572 -83
- tinygrad/codegen/linearizer.py +415 -395
- tinygrad/codegen/uops.py +415 -0
- tinygrad/device.py +183 -0
- tinygrad/dtype.py +113 -0
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +195 -0
- tinygrad/engine/realize.py +191 -0
- tinygrad/engine/schedule.py +362 -0
- tinygrad/engine/search.py +196 -0
- tinygrad/{mlops.py → function.py} +76 -55
- tinygrad/helpers.py +196 -89
- tinygrad/lazy.py +210 -371
- tinygrad/multi.py +169 -0
- tinygrad/nn/__init__.py +202 -22
- tinygrad/nn/datasets.py +7 -0
- tinygrad/nn/optim.py +112 -32
- tinygrad/nn/state.py +136 -39
- tinygrad/ops.py +119 -202
- tinygrad/renderer/__init__.py +61 -0
- tinygrad/renderer/assembly.py +276 -0
- tinygrad/renderer/cstyle.py +353 -166
- tinygrad/renderer/llvmir.py +150 -138
- tinygrad/runtime/autogen/amd_gpu.py +1900 -0
- tinygrad/runtime/autogen/comgr.py +865 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5761 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33328 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/hip_comgr.py +47 -0
- tinygrad/runtime/driver/hsa.py +143 -0
- tinygrad/runtime/graph/clang.py +38 -0
- tinygrad/runtime/graph/cuda.py +81 -0
- tinygrad/runtime/graph/hcq.py +143 -0
- tinygrad/runtime/graph/hsa.py +171 -0
- tinygrad/runtime/graph/metal.py +75 -0
- tinygrad/runtime/ops_amd.py +564 -0
- tinygrad/runtime/ops_clang.py +24 -77
- tinygrad/runtime/ops_cuda.py +175 -89
- tinygrad/runtime/ops_disk.py +56 -33
- tinygrad/runtime/ops_gpu.py +92 -95
- tinygrad/runtime/ops_hsa.py +278 -0
- tinygrad/runtime/ops_llvm.py +39 -60
- tinygrad/runtime/ops_metal.py +92 -74
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +630 -0
- tinygrad/runtime/ops_python.py +204 -0
- tinygrad/shape/shapetracker.py +86 -254
- tinygrad/shape/symbolic.py +166 -141
- tinygrad/shape/view.py +296 -0
- tinygrad/tensor.py +2619 -448
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
- tinygrad-0.9.0.dist-info/METADATA +227 -0
- tinygrad-0.9.0.dist-info/RECORD +60 -0
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/assembly.py +0 -190
- tinygrad/codegen/optimizer.py +0 -379
- tinygrad/codegen/search.py +0 -72
- tinygrad/graph.py +0 -83
- tinygrad/jit.py +0 -57
- tinygrad/nn/image.py +0 -100
- tinygrad/renderer/assembly_arm64.py +0 -169
- tinygrad/renderer/assembly_ptx.py +0 -98
- tinygrad/renderer/wgsl.py +0 -53
- tinygrad/runtime/lib.py +0 -113
- tinygrad/runtime/ops_cpu.py +0 -51
- tinygrad/runtime/ops_hip.py +0 -82
- tinygrad/runtime/ops_shm.py +0 -29
- tinygrad/runtime/ops_torch.py +0 -30
- tinygrad/runtime/ops_webgpu.py +0 -45
- tinygrad-0.7.0.dist-info/METADATA +0 -212
- tinygrad-0.7.0.dist-info/RECORD +0 -40
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
tinygrad/multi.py
ADDED
@@ -0,0 +1,169 @@
+from __future__ import annotations
+from typing import Optional, Union, Any, Tuple, List
+import functools, itertools, operator
+from tinygrad.helpers import all_same, all_int, dedup, round_up, prod, DEBUG, RING
+from tinygrad.dtype import DType, ConstType
+from tinygrad.ops import BinaryOps, LoadOps, UnaryOps, TernaryOps, ReduceOps
+from tinygrad.lazy import LazyBuffer
+from tinygrad.shape.shapetracker import sint
+
+def all_reduce(op: ReduceOps, lbs: List[LazyBuffer]) -> List[LazyBuffer]:
+  assert all_int(lbs[0].shape), f"does not support symbolic shape {lbs[0].shape}"
+  assert all_same([lb.shape[0] for lb in lbs]), "allreduce with uneven shards is undefined"
+  bop = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[op]
+
+  n_lbs, dim = len(lbs), prod(lbs[0].shape)
+  # Ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
+  # so just fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
+  use_ring = (RING >= 2 or (n_lbs > 2 and dim > 256_000 and RING >= 1))
+  if DEBUG >= 2: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{dim} | {lbs[0].dtype}")
+  if not use_ring:
+    return [functools.reduce(lambda x,y: x.e(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
+  factor = max(f for f in [32, 16, 8, 4, 2, 1] if dim % f == 0)
+  base, left = (dim // factor) // n_lbs, (dim // factor) % n_lbs
+  c_lens = [(base + 1) * factor if i < left else base * factor for i in range(n_lbs)]
+  acc = 0
+  chunks = [(acc, (acc := acc + i)) for i in c_lens if i > 0]
+  chunked = [[lb.reshape((dim,)).shrink(((s,e),)) for s,e in chunks] for lb in lbs]
+
+  # Scatter-reduce step
+  for step in range(n_lbs - 1):
+    for i in range(len(chunks)):
+      s, r = (i+step)%n_lbs, (i+step+1)%n_lbs
+      chunked[r][i] = chunked[r][i].e(bop, chunked[s][i].copy_to_device(chunked[r][i].device, force=True))
+
+  # Allgather step
+  for step in range(n_lbs - 1):
+    for i in range(len(chunks)):
+      s, r = (i+step-1)%n_lbs, (i+step)%n_lbs
+      chunked[r][i] = chunked[s][i].copy_to_device(chunked[r][i].device, force=True)
+
+  # Assemble chunks back
+  pads = [((s,dim-e),) for s,e in chunks]
+  return [functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), [c.pad(pads[i]) for i,c in enumerate(lb_c)]).reshape(lbs[0].shape) for lb_c in chunked]
+
+def to_sharded(lbs:List[LazyBuffer], axis:int) -> List[LazyBuffer]:
+  if DEBUG >= 3 and lbs[0].shape[axis] % len(lbs) != 0: print(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}")
+  sz = round_up(lbs[0].shape[axis], len(lbs)) // len(lbs)
+  return [lb.shrink(tuple((0,s) if a != axis else (sz*i,min(s,sz*(i+1))) for a,s in enumerate(lb.shape))) for i,lb in enumerate(lbs)]
+
+class MultiLazyBuffer:
+  def __init__(self, lbs:List[LazyBuffer], axis:Optional[int], real:Optional[List[bool]]=None):
+    assert all(isinstance(x, LazyBuffer) for x in lbs) and len(lbs), "all lbs must be LazyBuffers, and we need at least one of them"
+    assert all_same([x.dtype for x in lbs]), f"all multilazybuffer needs same dtype, getting {[x.dtype for x in lbs]}"
+    self.lbs, self.axis, self.dtype, self.device, self.real = lbs, axis, lbs[0].dtype, tuple(x.device for x in lbs), real or [True]*len(lbs)
+    if axis is not None:
+      splits = list(itertools.accumulate([lb.shape[axis] for lb in lbs], initial=0))
+      self.bounds = [(st,ed) for st,ed in zip(splits, splits[1:])]
+
+  @property
+  def shape(self):
+    return tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape))
+
+  @property
+  def size(self): return sum(x.size for x in self.real_lbs)
+
+  @property
+  def real_lbs(self): return [lb for lb,r in zip(self.lbs, self.real) if r]
+
+  def __repr__(self):
+    return f"<MLB {self.axis=} {self.real=} {chr(10)}{chr(10).join([f'{x.device} {x.st}' for x in self.lbs])}>"
+
+  @staticmethod
+  def from_sharded(lb:LazyBuffer, devices:Tuple[str, ...], axis:Optional[int]=None):
+    lbs = [lb.contiguous() if lb.base != lb and not lb.is_unrealized_unmasked_const() else lb] * len(devices)
+    sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(to_sharded(lbs, axis) if axis is not None else lbs, devices)]
+    return MultiLazyBuffer([lb if lb.is_unrealized_unmasked_const() else lb.contiguous() for lb in sharded_lbs], axis)
+
+  def copy_to_device(self, device:str) -> LazyBuffer:
+    if self.axis is None: return self.lbs[self.real.index(True)].copy_to_device(device)
+    sz = self.lbs[0].shape[self.axis]
+    llbs = []
+    for i,lb in enumerate([lb.copy_to_device(device) for lb in self.real_lbs]):
+      pad_arg = tuple((0,0) if a != self.axis else (sz*i, max(0, self.shape[self.axis]-sz*(i+1))) for a in range(len(lb.shape)))
+      llbs.append(lb.pad(pad_arg))
+    return functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), llbs)
+
+  # passthroughs
+  def is_realized(self) -> bool: return all(lb.base.realized is not None for lb, r in zip(self.lbs, self.real) if r is True)
+  def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis, self.real)
+  def const(self, val:ConstType) -> MultiLazyBuffer: return MultiLazyBuffer([x.const(val) for x in self.lbs], self.axis, self.real)
+  def assign(self, x:MultiLazyBuffer): return MultiLazyBuffer([s.assign(d) for s,d in zip(self.lbs, x.lbs)], self.axis, self.real)
+  def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis, self.real)
+
+  # elementwise is simple
+  def e(self, op:Union[LoadOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:MultiLazyBuffer, arg:Optional[Any]=None) -> MultiLazyBuffer:
+    msrcs = (self,)+in_srcs
+    assert all(isinstance(x, MultiLazyBuffer) for x in msrcs), f"all buffers must be MultiLazyBuffer {msrcs}"
+    assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"
+
+    # NOTE: they all have to share an axis, we always choose [-1]
+    axis = axes[-1] if len(axes := dedup([x.axis for x in msrcs if x.axis is not None])) else None
+    srcs = []
+    not_all_real = any(not all(mlb.real) for mlb in msrcs)
+    new_real = [all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])] if not_all_real else self.real
+    assert any(new_real), "output contains no real lb"
+    for mlb in msrcs:
+      if mlb.axis == axis or not_all_real: srcs.append(mlb.lbs)
+      elif mlb.axis is None and axis is not None: srcs.append(to_sharded(mlb.lbs, axis))
+      else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis))
+    # NOTE: lsrcs[-1].const(0) is correct for where
+    return MultiLazyBuffer([lsrcs[0].e(op, *lsrcs[1:], arg=arg) if r else lsrcs[-1].const(0) for lsrcs,r in zip(zip(*srcs),new_real)], axis, new_real)
+
+  def _shape_to_single_shard(self, shape:Tuple[sint, ...], lb:LazyBuffer) -> Tuple[sint, ...]:
+    return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape))
+
+  def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> MultiLazyBuffer:
+    if self.axis is not None and self.axis in axis:
+      # all-reduce on sharded axes
+      reduced_parts = [(x if r else x.const(0)).r(op, axis) for x,r in zip(self.lbs, self.real)]
+      if all(self.real): return MultiLazyBuffer(all_reduce(op, reduced_parts), None)
+      return MultiLazyBuffer(reduced_parts, None, self.real)
+    # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
+    return MultiLazyBuffer([x.r(op, axis) for x in self.lbs], self.axis, self.real)
+
+  # *** movement ops ***
+
+  def reshape(self, arg:Tuple[sint, ...]):
+    if self.axis is None: return MultiLazyBuffer([x.reshape(arg) for x in self.lbs], None, self.real)
+    arg_acc:List[sint] = list(itertools.accumulate(arg, operator.mul, initial=1))
+    # new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
+    # todo: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
+    new_axis = len(arg_acc) - arg_acc[::-1].index(prod(self.shape[:self.axis])) - 1
+    if arg[new_axis] != self.shape[self.axis]:
+      assert self.shape[self.axis] % len(self.real_lbs) == 0, f"cannot reshape on-axis for uneven shard {self.axis} {self.shape} {len(self.real_lbs)}"
+      assert arg[new_axis] % len(self.real_lbs) == 0, f"new on-axis shape must divide evenly between devices {new_axis} {arg} {len(self.real_lbs)}"
+    return MultiLazyBuffer([x.reshape(tuple(s if a != new_axis else
+                                            x.shape[self.axis] if s == self.shape[self.axis] else
+                                            s // len(self.real_lbs) for a,s in enumerate(arg))) for x in self.lbs],
+                           new_axis, self.real)
+
+  def pad(self, arg:Tuple[Tuple[sint, sint], ...]):
+    assert self.axis is None or arg[self.axis] == (0,0) or not all(self.real), f"padding not supported for {arg=}"
+    # pad on shard axis -> fill others with zeros and set real to all True
+    if self.axis is not None and arg[self.axis] != (0,0):
+      # pad back to whole axis, remove real mask
+      assert all(arg[i] == (0, 0) or i == self.axis for i in range(len(self.shape))), "cannot pad sharded and non-sharded axis at the same time"
+      assert arg[self.axis] == (sum(lb.shape[self.axis] for i,lb in enumerate(self.lbs) if i < self.real.index(True)), \
+                                sum(lb.shape[self.axis] for i,lb in enumerate(self.lbs) if i > self.real.index(True))), "can only pad to whole axis"
+      return MultiLazyBuffer([x if r else x.const(0) for x,r in zip(self.lbs, self.real)], self.axis)
+    return MultiLazyBuffer([x.pad(arg) for x in self.lbs], self.axis, self.real)
+  def expand(self, arg:Tuple[sint, ...]):
+    # NOTE: this assert isn't needed, sharded axis can have dim 1
+    assert self.axis is None or arg[self.axis] == self.shape[self.axis], f"expand not supported on sharded axis {arg=}"
+    return MultiLazyBuffer([x.expand(self._shape_to_single_shard(arg, x)) for x in self.lbs], self.axis, self.real)
+  def permute(self, arg:Tuple[int, ...]):
+    # all permutes supported!
+    return MultiLazyBuffer([x.permute(arg) for x in self.lbs], arg.index(self.axis) if self.axis is not None else None, self.real)
+  def shrink(self, arg:Tuple[Tuple[sint, sint], ...]):
+    assert self.axis is None or arg[self.axis] == (0, self.shape[self.axis]) or arg[self.axis] in self.bounds, f"shrinking not supported for {arg=}"
+    if self.axis is not None and arg[self.axis] in self.bounds and arg[self.axis] != (0, self.shape[self.axis]):
+      assert all(arg[i] == (0, s) or i == self.axis for i,s in enumerate(self.shape)), "cannot shrink sharded and non-sharded axis at the same time"
+      idx = self.bounds.index(arg[self.axis])
+      # zero out other lbs to not create lb reference
+      return MultiLazyBuffer([lb if i==idx else lb.const(0) for i,lb in enumerate(self.lbs)], self.axis, [i==idx for i in range(len(self.lbs))])
+    return MultiLazyBuffer([x.shrink(tuple((0, x.shape[self.axis]) if a == self.axis else s for a,s in enumerate(arg))) for x in self.lbs],
+                           self.axis, self.real)
+  def stride(self, arg:Tuple[int, ...]):
+    assert self.axis is None or arg[self.axis] == 1, "flipping not supported on sharded axis"
+    return MultiLazyBuffer([x.stride(arg) for x in self.lbs], self.axis, self.real)
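Usage note (not part of the diff): the user-facing entry point to MultiLazyBuffer.from_sharded is Tensor.shard, and a reduce over the sharded axis goes through the all_reduce function above. A minimal sketch, assuming two CUDA devices are available:

  from tinygrad import Tensor

  t = Tensor.rand(4, 256)
  st = t.shard(("CUDA:0", "CUDA:1"), axis=0)   # one (2, 256) shard per device, nothing moves until realization
  # summing over the sharded axis hits the all-reduce path in multi.py
  print(st.sum(axis=0).numpy().shape)          # (256,)

With only two shards, or fewer than roughly 256k elements, the code above picks the naive all-reduce; setting RING=2 in the environment forces the ring variant.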
tinygrad/nn/__init__.py
CHANGED
@@ -1,10 +1,35 @@
 import math
-from typing import Optional, Union, Tuple
+from typing import Optional, Union, Tuple, cast
 from tinygrad.tensor import Tensor
 from tinygrad.helpers import prod
+from tinygrad.nn import optim, state # noqa: F401
 
 class BatchNorm2d:
-
+  """
+  Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs with additional channel dimension).
+
+  - Described: https://paperswithcode.com/method/batch-normalization
+  - Paper: https://arxiv.org/abs/1502.03167v3
+
+  See: `Tensor.batchnorm`
+
+  ```python exec="true" session="tensor"
+  from tinygrad import Tensor, dtypes, nn
+  import numpy as np
+  np.set_printoptions(precision=4)
+  ```
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  norm = nn.BatchNorm2d(3)
+  t = Tensor.rand(2, 3, 4, 4)
+  print(t.mean().item(), t.std().item())
+  ```
+  ```python exec="true" source="above" session="tensor" result="python"
+  t = norm(t)
+  print(t.mean().item(), t.std().item())
+  ```
+  """
+  def __init__(self, sz:int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
     self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum
 
     if affine: self.weight, self.bias = Tensor.ones(sz), Tensor.zeros(sz)
@@ -19,14 +44,14 @@ class BatchNorm2d:
       # https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
       # There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
       batch_mean = x.mean(axis=(0,2,3))
-      y = (x - batch_mean.reshape(shape=[1, -1, 1, 1]))
+      y = (x - batch_mean.detach().reshape(shape=[1, -1, 1, 1])) # d(var)/d(mean) = 0
       batch_var = (y*y).mean(axis=(0,2,3))
       batch_invstd = batch_var.add(self.eps).pow(-0.5)
 
       # NOTE: wow, this is done all throughout training in most PyTorch models
       if self.track_running_stats:
-        self.running_mean.assign((1
-        self.running_var.assign((1
+        self.running_mean.assign((1-self.momentum) * self.running_mean + self.momentum * batch_mean.detach())
+        self.running_var.assign((1-self.momentum) * self.running_var + self.momentum * prod(y.shape)/(prod(y.shape)-y.shape[1]) * batch_var.detach())
         self.num_batches_tracked += 1
     else:
       batch_mean = self.running_mean
@@ -37,43 +62,139 @@ class BatchNorm2d:
 
 # TODO: these Conv lines are terrible
 def Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
+  """
+  Applies a 1D convolution over an input signal composed of several input planes.
+
+  See: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  conv = nn.Conv1d(1, 1, 3)
+  t = Tensor.rand(1, 1, 4)
+  print(t.numpy())
+  ```
+  ```python exec="true" source="above" session="tensor" result="python"
+  t = conv(t)
+  print(t.numpy())
+  ```
+  """
   return Conv2d(in_channels, out_channels, (kernel_size,), stride, padding, dilation, groups, bias)
 
 class Conv2d:
+  """
+  Applies a 2D convolution over an input signal composed of several input planes.
+
+  See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  conv = nn.Conv2d(1, 1, 3)
+  t = Tensor.rand(1, 1, 4, 4)
+  print(t.numpy())
+  ```
+  ```python exec="true" source="above" session="tensor" result="python"
+  t = conv(t)
+  print(t.numpy())
+  ```
+  """
   def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
     self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
     self.stride, self.padding, self.dilation, self.groups = stride, padding, dilation, groups
-    self.weight =
-    bound = 1 / math.sqrt(prod(self.weight.shape[1:]))
+    self.weight = self.initialize_weight(out_channels, in_channels, groups)
+    bound = 1 / math.sqrt(cast(int, prod(self.weight.shape[1:]))) # weight shape is always ints but mypy cannot tell
     self.bias = Tensor.uniform(out_channels, low=-bound, high=bound) if bias else None
 
-  def __call__(self, x):
+  def __call__(self, x:Tensor):
     return x.conv2d(self.weight, self.bias, padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
 
+  def initialize_weight(self, out_channels, in_channels, groups):
+    return Tensor.kaiming_uniform(out_channels, in_channels//groups, *self.kernel_size, a=math.sqrt(5))
+
 def ConvTranspose1d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
+  """
+  Applies a 1D transposed convolution operator over an input signal composed of several input planes.
+
+  See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  conv = nn.ConvTranspose1d(1, 1, 3)
+  t = Tensor.rand(1, 1, 4)
+  print(t.numpy())
+  ```
+  ```python exec="true" source="above" session="tensor" result="python"
+  t = conv(t)
+  print(t.numpy())
+  ```
+  """
   return ConvTranspose2d(in_channels, out_channels, (kernel_size,), stride, padding, output_padding, dilation, groups, bias)
 
-class ConvTranspose2d:
+class ConvTranspose2d(Conv2d):
+  """
+  Applies a 2D transposed convolution operator over an input image.
+
+  See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  conv = nn.ConvTranspose2d(1, 1, 3)
+  t = Tensor.rand(1, 1, 4, 4)
+  print(t.numpy())
+  ```
+  ```python exec="true" source="above" session="tensor" result="python"
+  t = conv(t)
+  print(t.numpy())
+  ```
+  """
   def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
-
-    self.
-    self.weight = Tensor.kaiming_uniform(in_channels, out_channels//groups, *self.kernel_size, a=math.sqrt(5))
-    bound = 1 / math.sqrt(prod(self.weight.shape[1:]))
-    self.bias = Tensor.uniform(out_channels, low=-bound, high=bound) if bias else None
+    super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
+    self.output_padding = output_padding
 
-  def __call__(self, x):
-    return x.conv_transpose2d(self.weight, self.bias, padding=self.padding, output_padding=self.output_padding, stride=self.stride,
+  def __call__(self, x:Tensor):
+    return x.conv_transpose2d(self.weight, self.bias, padding=self.padding, output_padding=self.output_padding, stride=self.stride,
+                              dilation=self.dilation, groups=self.groups)
+
+  def initialize_weight(self, out_channels, in_channels, groups):
+    return Tensor.kaiming_uniform(in_channels, out_channels//groups, *self.kernel_size, a=math.sqrt(5))
 
 class Linear:
+  """
+  Applies a linear transformation to the incoming data.
+
+  See: https://pytorch.org/docs/stable/generated/torch.nn.Linear
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  lin = nn.Linear(3, 4)
+  t = Tensor.rand(2, 3)
+  print(t.numpy())
+  ```
+  ```python exec="true" source="above" session="tensor" result="python"
+  t = lin(t)
+  print(t.numpy())
+  ```
+  """
   def __init__(self, in_features, out_features, bias=True):
+    # TODO: is this init good? torch inits to uniform(-1/sqrt(in_features), 1/sqrt(in_features))
     self.weight = Tensor.kaiming_uniform(out_features, in_features, a=math.sqrt(5))
-    bound = 1 / math.sqrt(
+    bound = 1 / math.sqrt(in_features)
    self.bias = Tensor.uniform(out_features, low=-bound, high=bound) if bias else None
 
-  def __call__(self, x):
+  def __call__(self, x:Tensor):
     return x.linear(self.weight.transpose(), self.bias)
 
 class GroupNorm:
+  """
+  Applies Group Normalization over a mini-batch of inputs.
+
+  - Described: https://paperswithcode.com/method/group-normalization
+  - Paper: https://arxiv.org/abs/1803.08494v3
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  norm = nn.GroupNorm(2, 12)
+  t = Tensor.rand(2, 12, 4, 4) * 2 + 1
+  print(t.mean().item(), t.std().item())
+  ```
+  ```python exec="true" source="above" session="tensor" result="python"
+  t = norm(t)
+  print(t.mean().item(), t.std().item())
+  ```
+  """
   def __init__(self, num_groups:int, num_channels:int, eps:float=1e-5, affine:bool=True):
     self.num_groups, self.num_channels, self.eps = num_groups, num_channels, eps
     self.weight: Optional[Tensor] = Tensor.ones(num_channels) if affine else None
@@ -89,6 +210,22 @@ class GroupNorm:
     return x * self.weight.reshape(1, -1, *[1] * (len(x.shape)-2)) + self.bias.reshape(1, -1, *[1] * (len(x.shape)-2))
 
 class InstanceNorm:
+  """
+  Applies Instance Normalization over a mini-batch of inputs.
+
+  - Described: https://paperswithcode.com/method/instance-normalization
+  - Paper: https://arxiv.org/abs/1607.08022v3
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  norm = nn.InstanceNorm(3)
+  t = Tensor.rand(2, 3, 4, 4) * 2 + 1
+  print(t.mean().item(), t.std().item())
+  ```
+  ```python exec="true" source="above" session="tensor" result="python"
+  t = norm(t)
+  print(t.mean().item(), t.std().item())
+  ```
+  """
   def __init__(self, num_features:int, eps:float=1e-5, affine:bool=True):
     self.num_features, self.eps = num_features, eps
     self.weight: Optional[Tensor] = Tensor.ones(num_features) if affine else None
@@ -100,6 +237,22 @@ class InstanceNorm:
     return x * self.weight.reshape(1, -1, *[1] * (len(x.shape)-2)) + self.bias.reshape(1, -1, *[1] * (len(x.shape)-2))
 
 class LayerNorm:
+  """
+  Applies Layer Normalization over a mini-batch of inputs.
+
+  - Described: https://paperswithcode.com/method/layer-normalization
+  - Paper: https://arxiv.org/abs/1607.06450v1
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  norm = nn.LayerNorm(3)
+  t = Tensor.rand(2, 5, 3) * 2 + 1
+  print(t.mean().item(), t.std().item())
+  ```
+  ```python exec="true" source="above" session="tensor" result="python"
+  t = norm(t)
+  print(t.mean().item(), t.std().item())
+  ```
+  """
   def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps:float=1e-5, elementwise_affine:bool=True):
     self.normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
     self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(self.normalized_shape))), eps, elementwise_affine
@@ -112,13 +265,40 @@ class LayerNorm:
     return x * self.weight + self.bias
 
 class LayerNorm2d(LayerNorm):
+  """
+  Applies Layer Normalization over a mini-batch of 2D inputs.
+
+  See: `LayerNorm`
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  norm = nn.LayerNorm2d(3)
+  t = Tensor.rand(2, 3, 4, 4) * 2 + 1
+  print(t.mean().item(), t.std().item())
+  ```
+  ```python exec="true" source="above" session="tensor" result="python"
+  t = norm(t)
+  print(t.mean().item(), t.std().item())
+  ```
+  """
   def __call__(self, x): return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
 
 class Embedding:
+  """
+  A simple lookup table that stores embeddings of a fixed dictionary and size.
+
+  See: https://pytorch.org/docs/stable/generated/torch.nn.Embedding
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  emb = nn.Embedding(10, 3)
+  print(emb(Tensor([1, 2, 3, 1])).numpy())
+  ```
+  """
   def __init__(self, vocab_size:int, embed_size:int):
-    self.
-    self.weight = Tensor.glorot_uniform(vocab_size, embed_size)
+    self.vocab_sz, self.embed_sz, self.weight = vocab_size, embed_size, Tensor.glorot_uniform(vocab_size, embed_size)
 
   def __call__(self, idx:Tensor) -> Tensor:
-    if
-
+    if idx.numel() == 0: return Tensor.empty(idx.shape+(self.embed_sz,), device=self.weight.device)
+    arange_shp, weight_shp, big_shp = (1, 1, self.vocab_sz, 1), (1, 1, self.vocab_sz, self.embed_sz), idx.shape+(self.vocab_sz, self.embed_sz,)
+    if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).reshape(arange_shp)
+    arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1,)).expand(big_shp), self.weight.reshape(weight_shp).expand(big_shp)
+    return (arange == idx).mul(vals).sum(2)
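Aside (not from the diff): because ConvTranspose2d now subclasses Conv2d and only overrides initialize_weight, the same hook can be reused for custom weight initialization in user code. A small sketch; the ZeroConv2d name is made up for illustration:

  from tinygrad import Tensor, nn

  class ZeroConv2d(nn.Conv2d):
    # Conv2d.__init__ is reused as-is; only the weight construction changes
    def initialize_weight(self, out_channels, in_channels, groups):
      return Tensor.zeros(out_channels, in_channels//groups, *self.kernel_size)

  conv = ZeroConv2d(3, 8, 3)
  print(conv(Tensor.rand(1, 3, 16, 16)).shape)  # (1, 8, 14, 14)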
tinygrad/nn/datasets.py
ADDED
@@ -0,0 +1,7 @@
+import gzip
+from tinygrad import Tensor, fetch
+
+def _fetch_mnist(file, offset): return Tensor(gzip.open(fetch("https://storage.googleapis.com/cvdf-datasets/mnist/"+file)).read()[offset:])
+def mnist():
+  return _fetch_mnist("train-images-idx3-ubyte.gz", 0x10).reshape(-1, 1, 28, 28), _fetch_mnist("train-labels-idx1-ubyte.gz", 8), \
+         _fetch_mnist("t10k-images-idx3-ubyte.gz", 0x10).reshape(-1, 1, 28, 28), _fetch_mnist("t10k-labels-idx1-ubyte.gz", 8)
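For reference (not part of the diff): the new helper downloads the four MNIST files on first call and returns image and label tensors directly, so typical usage is:

  from tinygrad.nn.datasets import mnist

  X_train, Y_train, X_test, Y_test = mnist()
  print(X_train.shape, Y_train.shape)  # (60000, 1, 28, 28) (60000,)
  print(X_test.shape, Y_test.shape)    # (10000, 1, 28, 28) (10000,)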
tinygrad/nn/optim.py
CHANGED
@@ -1,64 +1,144 @@
 # sorted in order of increasing complexity
 from typing import List
-from tinygrad.helpers import dedup
+from tinygrad.helpers import dedup, flatten, getenv
 from tinygrad.tensor import Tensor
+from tinygrad.dtype import dtypes, least_upper_dtype
 
 class Optimizer:
+  """
+  Base class for all optimizers.
+  """
   def __init__(self, params: List[Tensor], lr: float):
     # if it's None, but being put into an optimizer, set it to True
     for x in params:
       if x.requires_grad is None: x.requires_grad = True
 
     self.params: List[Tensor] = dedup([x for x in params if x.requires_grad])
+    assert len(self.params) != 0, "optimizer must have at least one param"
+    self.device = self.params[0].device
     self.buffers: List[Tensor] = dedup([x for x in params if not x.requires_grad]) # buffers are still realized
-
+    # store lr in at least float32 precision
+    self.lr = Tensor(lr if getenv("CONST_LR") else [lr], requires_grad=False, device=self.device,
+                     dtype=least_upper_dtype(dtypes.default_float, dtypes.float32))
 
   def zero_grad(self):
+    """
+    Zeroes the gradients of all the parameters.
+    """
     for param in self.params: param.grad = None
 
-  def
-
-
-
-
+  def step(self):
+    """
+    Performs a single optimization step.
+    """
+    Tensor.realize(*self.schedule_step())
+  def schedule_step(self) -> List[Tensor]:
+    """
+    Returns the tensors that need to be realized to perform a single optimization step.
+    """
+    assert Tensor.training, (
+            f"""Tensor.training={Tensor.training}, Tensor.training must be enabled to use the optimizer.
+                - help: Consider setting Tensor.training=True before calling Optimizer.step().""")
+    return self._step()+self.params+self.buffers
+  def _step(self) -> List[Tensor]: raise NotImplementedError
 
-class
-
+class OptimizerGroup(Optimizer):
+  """
+  Combines multiple optimizers into one.
+  """
+  def __init__(self, *optimizers: Optimizer): # pylint: disable=super-init-not-called
+    self.optimizers = optimizers
+    self.params, self.buffers = flatten([o.params for o in self.optimizers]), flatten([o.buffers for o in self.optimizers])
+  def __getitem__(self, i): return self.optimizers[i]
+  def zero_grad(self): [o.zero_grad() for o in self.optimizers]
+  def _step(self) -> List[Tensor]: return [x for o in self.optimizers for x in o._step()]
+
+# LARS is essentially just trust ratio to SGD so if we just set the trust coeff 0.0 its just standard SGD.
+def SGD(params: List[Tensor], lr=0.001, momentum=0.0, weight_decay=0.0, nesterov=False, classic=False):
+  """
+  Stochastic Gradient Descent (SGD) optimizer with optional momentum and weight decay.
+
+  `classic` is a boolean flag that determines whether to use the popular momentum update rule or the classic momentum update rule.
+
+  - Described: https://paperswithcode.com/method/sgd
+  """
+  return LARS(params, lr, momentum, weight_decay, nesterov, classic, tcoef=0.0)
+
+class LARS(Optimizer):
+  """
+  Layer-wise Adaptive Rate Scaling (LARS) optimizer with optional momentum and weight decay.
+
+  - Described: https://paperswithcode.com/method/lars
+  - Paper: https://arxiv.org/abs/1708.03888v3
+  """
+  def __init__(self, params:List[Tensor], lr=0.001, momentum=0.9, weight_decay=1e-4, nesterov=False, classic=True, tcoef=0.001):
     super().__init__(params, lr)
-    self.momentum, self.wd, self.nesterov = momentum, weight_decay, nesterov
-    self.b = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params] if self.momentum else []
+    self.momentum, self.wd, self.nesterov, self.classic, self.tcoef = momentum, weight_decay, nesterov, classic, tcoef
+    self.b = [Tensor.zeros(*t.shape, dtype=t.dtype, device=t.device, requires_grad=False) for t in self.params] if self.momentum else []
 
-
-  def step(self) -> None:
+  def _step(self) -> List[Tensor]:
     for i, t in enumerate(self.params):
       assert t.grad is not None
-
+      # contiguous is needed since the grads can allegedly form a "diamond"
+      # TODO: fix this in lazy.py
+      g = t.grad.contiguous()
+      if self.tcoef != 0:
+        r1 = t.detach().square().sum().sqrt()
+        r2 = g.square().sum().sqrt()
+        r = (r1 > 0).where((r2 > 0).where(self.tcoef * r1 / (r2 + self.wd * r1), 1.0), 1.0)
+      else: r = 1.0
+      g = g + self.wd * t.detach()
+      # classic momentum does post learning rate update
+      if self.classic: g = g * r * self.lr
       if self.momentum:
-        self.b[i].assign(self.momentum * self.b[i] + g)
+        self.b[i].assign(self.momentum * self.b[i] + g) # NOTE: self.b[i] is zero on the first run, no if required
         g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
-
-
+      # popular momentum does pre learning rate update
+      if not self.classic: g = g * r * self.lr
+      t.assign((t.detach() - g).cast(t.dtype))
+    return self.b
 
 # LAMB is essentially just the trust ratio part of LARS applied to Adam/W so if we just set the trust ratio to 1.0 its just Adam/W.
-def AdamW(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8,
-
+def AdamW(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.01):
+  """
+  AdamW optimizer with optional weight decay.
+
+  - Described: https://paperswithcode.com/method/adamw
+  - Paper: https://arxiv.org/abs/1711.05101v3
+  """
+  return LAMB(params, lr, b1, b2, eps, weight_decay, adam=True)
+def Adam(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
+  """
+  Adam optimizer.
+
+  - Described: https://paperswithcode.com/method/adam
+  - Paper: https://arxiv.org/abs/1412.6980
+  """
+  return LAMB(params, lr, b1, b2, eps, 0.0, adam=True)
 
 class LAMB(Optimizer):
-
+  """
+  LAMB optimizer with optional weight decay.
+
+  - Described: https://paperswithcode.com/method/lamb
+  - Paper: https://arxiv.org/abs/1904.00962
+  """
+  def __init__(self, params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, adam=False):
     super().__init__(params, lr)
-    self.b1, self.b2, self.eps, self.wd, self.adam
-    self.
-    self.
+    self.b1, self.b2, self.eps, self.wd, self.adam = b1, b2, eps, weight_decay, adam
+    self.b1_t, self.b2_t = (Tensor([1], dtype=dtypes.float32, device=self.device, requires_grad=False).realize() for _ in [b1, b2])
+    self.m = [Tensor.zeros(*t.shape, dtype=dtypes.float32, device=t.device, requires_grad=False).contiguous() for t in self.params]
+    self.v = [Tensor.zeros(*t.shape, dtype=dtypes.float32, device=t.device, requires_grad=False).contiguous() for t in self.params]
 
-  def
-    self.
+  def _step(self) -> List[Tensor]:
+    self.b1_t *= self.b1
+    self.b2_t *= self.b2
     for i, t in enumerate(self.params):
       assert t.grad is not None
-
-      self.
-
-
-      v_hat = self.v[i] / (1.0 - self.b2**self.t)
+      self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * t.grad)
+      self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad))
+      m_hat = self.m[i] / (1.0 - self.b1_t)
+      v_hat = self.v[i] / (1.0 - self.b2_t)
       up = (m_hat / (v_hat.sqrt() + self.eps)) + self.wd * t.detach()
       if not self.adam:
         r1 = t.detach().square().sum().sqrt()
@@ -66,5 +146,5 @@ class LAMB(Optimizer):
         r = Tensor.where(r1 > 0, Tensor.where(r2 > 0, r1 / r2, 1.0), 1.0)
       else:
         r = 1.0
-      t.assign(t.detach() - self.lr * r * up)
-      self.
+      t.assign((t.detach() - self.lr * r * up).cast(t.dtype))
+    return [self.b1_t, self.b2_t] + self.m + self.v
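Usage sketch (not part of the diff) for the new step/schedule_step contract: Tensor.training must be enabled, and step() realizes everything schedule_step() returns. The parameter and data below are dummies, and Tensor.train() is the context manager that flips Tensor.training:

  from tinygrad import Tensor
  from tinygrad.nn.optim import Adam

  w = Tensor.rand(4, 2, requires_grad=True)
  opt = Adam([w], lr=1e-3)
  with Tensor.train():                      # sets Tensor.training = True
    x, y = Tensor.rand(8, 4), Tensor.rand(8, 2)
    loss = (x.matmul(w) - y).square().mean()
    opt.zero_grad()
    loss.backward()
    opt.step()                              # same as Tensor.realize(*opt.schedule_step())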