tinygrad 0.8.0-py3-none-any.whl → 0.9.1-py3-none-any.whl
This diff compares the published contents of two versions of the package as they appear in their public registry. It is provided for informational purposes only.
- tinygrad/__init__.py +6 -6
- tinygrad/codegen/__init__.py +0 -0
- tinygrad/codegen/kernel.py +253 -225
- tinygrad/codegen/linearizer.py +398 -436
- tinygrad/codegen/uops.py +451 -0
- tinygrad/device.py +268 -274
- tinygrad/dtype.py +56 -40
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +198 -0
- tinygrad/engine/realize.py +192 -0
- tinygrad/engine/schedule.py +370 -0
- tinygrad/engine/search.py +199 -0
- tinygrad/{mlops.py → function.py} +40 -32
- tinygrad/helpers.py +144 -46
- tinygrad/lazy.py +143 -242
- tinygrad/multi.py +173 -0
- tinygrad/nn/__init__.py +180 -9
- tinygrad/nn/datasets.py +8 -0
- tinygrad/nn/optim.py +106 -28
- tinygrad/nn/state.py +87 -19
- tinygrad/ops.py +104 -45
- tinygrad/renderer/__init__.py +65 -0
- tinygrad/renderer/assembly.py +269 -0
- tinygrad/renderer/cstyle.py +308 -210
- tinygrad/renderer/llvmir.py +119 -124
- tinygrad/runtime/__init__.py +0 -0
- tinygrad/runtime/autogen/amd_gpu.py +13403 -0
- tinygrad/runtime/autogen/comgr.py +891 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5893 -0
- tinygrad/runtime/autogen/io_uring.py +1486 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33597 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/__init__.py +0 -0
- tinygrad/runtime/driver/hip_comgr.py +56 -0
- tinygrad/runtime/graph/__init__.py +0 -0
- tinygrad/runtime/graph/clang.py +39 -0
- tinygrad/runtime/graph/cuda.py +59 -54
- tinygrad/runtime/graph/hcq.py +187 -0
- tinygrad/runtime/graph/metal.py +37 -41
- tinygrad/runtime/ops_amd.py +550 -0
- tinygrad/runtime/ops_clang.py +16 -14
- tinygrad/runtime/ops_cuda.py +129 -37
- tinygrad/runtime/ops_disk.py +111 -43
- tinygrad/runtime/ops_gpu.py +52 -50
- tinygrad/runtime/ops_llvm.py +36 -56
- tinygrad/runtime/ops_metal.py +41 -24
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +625 -0
- tinygrad/runtime/ops_python.py +208 -0
- tinygrad/shape/__init__.py +0 -0
- tinygrad/shape/shapetracker.py +46 -107
- tinygrad/shape/symbolic.py +99 -98
- tinygrad/shape/view.py +162 -45
- tinygrad/tensor.py +2492 -483
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +1 -1
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +31 -13
- tinygrad-0.9.1.dist-info/RECORD +63 -0
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
- tinygrad/features/image.py +0 -93
- tinygrad/features/multi.py +0 -103
- tinygrad/features/search.py +0 -160
- tinygrad/graph.py +0 -106
- tinygrad/jit.py +0 -152
- tinygrad/realize.py +0 -50
- tinygrad/runtime/graph/hip.py +0 -24
- tinygrad/runtime/ops_cpu.py +0 -45
- tinygrad/runtime/ops_hip.py +0 -97
- tinygrad/runtime/ops_torch.py +0 -49
- tinygrad-0.8.0.dist-info/RECORD +0 -41
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
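The list above reflects the 0.9.1 reorganization: scheduling, the JIT, kernel search, and graphing moved from the top level and `tinygrad/features/` into the new `tinygrad/engine/` package, `mlops.py` was renamed to `function.py`, multi-device support moved to `tinygrad/multi.py`, and the numpy/torch/HIP runtimes were replaced by `ops_python.py`, `ops_amd.py`, and `ops_nv.py`. As a rough, non-authoritative sketch of what that means for downstream imports (module paths are taken from the list above; whether each symbol kept its exact name is not visible in this diff):

```python
# Hypothetical migration notes based only on the file moves listed above.
# The top-level re-exports (Tensor, TinyJit, dtypes) are expected to keep working:
from tinygrad import Tensor, TinyJit, dtypes

# Internal module paths, 0.8.0 -> 0.9.1 (per the file list; names inside may differ):
#   tinygrad/jit.py             -> tinygrad/engine/jit.py
#   tinygrad/realize.py         -> tinygrad/engine/realize.py
#   tinygrad/graph.py           -> tinygrad/engine/graph.py
#   tinygrad/features/search.py -> tinygrad/engine/search.py
#   tinygrad/features/multi.py  -> tinygrad/multi.py
#   tinygrad/mlops.py           -> tinygrad/function.py
```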
tinygrad/nn/__init__.py
CHANGED
````diff
@@ -2,9 +2,33 @@ import math
 from typing import Optional, Union, Tuple, cast
 from tinygrad.tensor import Tensor
 from tinygrad.helpers import prod
-from tinygrad.nn import optim, state # noqa: F401
+from tinygrad.nn import optim, state, datasets # noqa: F401
 
 class BatchNorm2d:
+  """
+  Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs with additional channel dimension).
+
+  - Described: https://paperswithcode.com/method/batch-normalization
+  - Paper: https://arxiv.org/abs/1502.03167v3
+
+  See: `Tensor.batchnorm`
+
+  ```python exec="true" session="tensor"
+  from tinygrad import Tensor, dtypes, nn
+  import numpy as np
+  np.set_printoptions(precision=4)
+  ```
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  norm = nn.BatchNorm2d(3)
+  t = Tensor.rand(2, 3, 4, 4)
+  print(t.mean().item(), t.std().item())
+  ```
+  ```python exec="true" source="above" session="tensor" result="python"
+  t = norm(t)
+  print(t.mean().item(), t.std().item())
+  ```
+  """
   def __init__(self, sz:int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
     self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum
 
@@ -20,7 +44,7 @@ class BatchNorm2d:
     # https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
     # There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
     batch_mean = x.mean(axis=(0,2,3))
-    y = (x - batch_mean.reshape(shape=[1, -1, 1, 1]))
+    y = (x - batch_mean.detach().reshape(shape=[1, -1, 1, 1]))  # d(var)/d(mean) = 0
     batch_var = (y*y).mean(axis=(0,2,3))
     batch_invstd = batch_var.add(self.eps).pow(-0.5)
 
@@ -38,9 +62,39 @@ class BatchNorm2d:
 
 # TODO: these Conv lines are terrible
 def Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
+  """
+  Applies a 1D convolution over an input signal composed of several input planes.
+
+  See: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  conv = nn.Conv1d(1, 1, 3)
+  t = Tensor.rand(1, 1, 4)
+  print(t.numpy())
+  ```
+  ```python exec="true" source="above" session="tensor" result="python"
+  t = conv(t)
+  print(t.numpy())
+  ```
+  """
   return Conv2d(in_channels, out_channels, (kernel_size,), stride, padding, dilation, groups, bias)
 
 class Conv2d:
+  """
+  Applies a 2D convolution over an input signal composed of several input planes.
+
+  See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  conv = nn.Conv2d(1, 1, 3)
+  t = Tensor.rand(1, 1, 4, 4)
+  print(t.numpy())
+  ```
+  ```python exec="true" source="above" session="tensor" result="python"
+  t = conv(t)
+  print(t.numpy())
+  ```
+  """
   def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
     self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
     self.stride, self.padding, self.dilation, self.groups = stride, padding, dilation, groups
@@ -55,9 +109,39 @@ class Conv2d:
     return Tensor.kaiming_uniform(out_channels, in_channels//groups, *self.kernel_size, a=math.sqrt(5))
 
 def ConvTranspose1d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
+  """
+  Applies a 1D transposed convolution operator over an input signal composed of several input planes.
+
+  See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  conv = nn.ConvTranspose1d(1, 1, 3)
+  t = Tensor.rand(1, 1, 4)
+  print(t.numpy())
+  ```
+  ```python exec="true" source="above" session="tensor" result="python"
+  t = conv(t)
+  print(t.numpy())
+  ```
+  """
   return ConvTranspose2d(in_channels, out_channels, (kernel_size,), stride, padding, output_padding, dilation, groups, bias)
 
 class ConvTranspose2d(Conv2d):
+  """
+  Applies a 2D transposed convolution operator over an input image.
+
+  See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  conv = nn.ConvTranspose2d(1, 1, 3)
+  t = Tensor.rand(1, 1, 4, 4)
+  print(t.numpy())
+  ```
+  ```python exec="true" source="above" session="tensor" result="python"
+  t = conv(t)
+  print(t.numpy())
+  ```
+  """
   def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
     super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
     self.output_padding = output_padding
@@ -70,6 +154,21 @@ class ConvTranspose2d(Conv2d):
     return Tensor.kaiming_uniform(in_channels, out_channels//groups, *self.kernel_size, a=math.sqrt(5))
 
 class Linear:
+  """
+  Applies a linear transformation to the incoming data.
+
+  See: https://pytorch.org/docs/stable/generated/torch.nn.Linear
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  lin = nn.Linear(3, 4)
+  t = Tensor.rand(2, 3)
+  print(t.numpy())
+  ```
+  ```python exec="true" source="above" session="tensor" result="python"
+  t = lin(t)
+  print(t.numpy())
+  ```
+  """
   def __init__(self, in_features, out_features, bias=True):
     # TODO: is this init good? torch inits to uniform(-1/sqrt(in_features), 1/sqrt(in_features))
     self.weight = Tensor.kaiming_uniform(out_features, in_features, a=math.sqrt(5))
@@ -80,6 +179,22 @@ class Linear:
     return x.linear(self.weight.transpose(), self.bias)
 
 class GroupNorm:
+  """
+  Applies Group Normalization over a mini-batch of inputs.
+
+  - Described: https://paperswithcode.com/method/group-normalization
+  - Paper: https://arxiv.org/abs/1803.08494v3
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  norm = nn.GroupNorm(2, 12)
+  t = Tensor.rand(2, 12, 4, 4) * 2 + 1
+  print(t.mean().item(), t.std().item())
+  ```
+  ```python exec="true" source="above" session="tensor" result="python"
+  t = norm(t)
+  print(t.mean().item(), t.std().item())
+  ```
+  """
   def __init__(self, num_groups:int, num_channels:int, eps:float=1e-5, affine:bool=True):
     self.num_groups, self.num_channels, self.eps = num_groups, num_channels, eps
     self.weight: Optional[Tensor] = Tensor.ones(num_channels) if affine else None
@@ -95,6 +210,22 @@ class GroupNorm:
     return x * self.weight.reshape(1, -1, *[1] * (len(x.shape)-2)) + self.bias.reshape(1, -1, *[1] * (len(x.shape)-2))
 
 class InstanceNorm:
+  """
+  Applies Instance Normalization over a mini-batch of inputs.
+
+  - Described: https://paperswithcode.com/method/instance-normalization
+  - Paper: https://arxiv.org/abs/1607.08022v3
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  norm = nn.InstanceNorm(3)
+  t = Tensor.rand(2, 3, 4, 4) * 2 + 1
+  print(t.mean().item(), t.std().item())
+  ```
+  ```python exec="true" source="above" session="tensor" result="python"
+  t = norm(t)
+  print(t.mean().item(), t.std().item())
+  ```
+  """
   def __init__(self, num_features:int, eps:float=1e-5, affine:bool=True):
     self.num_features, self.eps = num_features, eps
     self.weight: Optional[Tensor] = Tensor.ones(num_features) if affine else None
@@ -106,6 +237,22 @@ class InstanceNorm:
     return x * self.weight.reshape(1, -1, *[1] * (len(x.shape)-2)) + self.bias.reshape(1, -1, *[1] * (len(x.shape)-2))
 
 class LayerNorm:
+  """
+  Applies Layer Normalization over a mini-batch of inputs.
+
+  - Described: https://paperswithcode.com/method/layer-normalization
+  - Paper: https://arxiv.org/abs/1607.06450v1
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  norm = nn.LayerNorm(3)
+  t = Tensor.rand(2, 5, 3) * 2 + 1
+  print(t.mean().item(), t.std().item())
+  ```
+  ```python exec="true" source="above" session="tensor" result="python"
+  t = norm(t)
+  print(t.mean().item(), t.std().item())
+  ```
+  """
   def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps:float=1e-5, elementwise_affine:bool=True):
     self.normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
     self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(self.normalized_shape))), eps, elementwise_affine
@@ -118,16 +265,40 @@ class LayerNorm:
     return x * self.weight + self.bias
 
 class LayerNorm2d(LayerNorm):
+  """
+  Applies Layer Normalization over a mini-batch of 2D inputs.
+
+  See: `LayerNorm`
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  norm = nn.LayerNorm2d(3)
+  t = Tensor.rand(2, 3, 4, 4) * 2 + 1
+  print(t.mean().item(), t.std().item())
+  ```
+  ```python exec="true" source="above" session="tensor" result="python"
+  t = norm(t)
+  print(t.mean().item(), t.std().item())
+  ```
+  """
   def __call__(self, x): return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
 
 class Embedding:
+  """
+  A simple lookup table that stores embeddings of a fixed dictionary and size.
+
+  See: https://pytorch.org/docs/stable/generated/torch.nn.Embedding
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  emb = nn.Embedding(10, 3)
+  print(emb(Tensor([1, 2, 3, 1])).numpy())
+  ```
+  """
   def __init__(self, vocab_size:int, embed_size:int):
-    self.
-    self.weight = Tensor.glorot_uniform(vocab_size, embed_size)
+    self.vocab_sz, self.embed_sz, self.weight = vocab_size, embed_size, Tensor.glorot_uniform(vocab_size, embed_size)
 
   def __call__(self, idx:Tensor) -> Tensor:
-    if
-
-
-
-    return (
+    if idx.numel() == 0: return Tensor.empty(idx.shape+(self.embed_sz,), device=self.weight.device)
+    arange_shp, weight_shp, big_shp = (1, 1, self.vocab_sz, 1), (1, 1, self.vocab_sz, self.embed_sz), idx.shape+(self.vocab_sz, self.embed_sz,)
+    if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).reshape(arange_shp)
+    arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1,)).expand(big_shp), self.weight.reshape(weight_shp).expand(big_shp)
+    return (arange == idx).mul(vals).sum(2)
````
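The rewritten `Embedding.__call__` above replaces the old lookup with a broadcasted one-hot gather: an `arange` over the vocabulary is compared against the indices, and the resulting mask selects rows of the weight matrix. A standalone sketch of that pattern, with shapes following the `arange_shp`/`weight_shp`/`big_shp` names in the diff (assumes a working tinygrad install; the example values are made up):

```python
from tinygrad import Tensor

vocab_sz, embed_sz = 10, 3
weight = Tensor.glorot_uniform(vocab_sz, embed_sz)        # (vocab_sz, embed_sz), like Embedding.weight
idx = Tensor([[1, 2, 3, 1]])                              # (batch=1, seq=4) integer indices

big_shp = idx.shape + (vocab_sz, embed_sz)
arange = Tensor.arange(vocab_sz).reshape(1, 1, vocab_sz, 1).expand(big_shp)
onehot = arange == idx.reshape(idx.shape + (1, 1)).expand(big_shp)   # one-hot mask over the vocab axis
out = onehot.mul(weight.reshape(1, 1, vocab_sz, embed_sz).expand(big_shp)).sum(2)
print(out.shape)  # (1, 4, 3): one embedding row per index
```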
tinygrad/nn/datasets.py
ADDED
````diff
@@ -0,0 +1,8 @@
+import gzip
+from tinygrad.tensor import Tensor
+from tinygrad.helpers import fetch
+
+def _fetch_mnist(file, offset): return Tensor(gzip.open(fetch("https://storage.googleapis.com/cvdf-datasets/mnist/"+file)).read()[offset:])
+def mnist():
+  return _fetch_mnist("train-images-idx3-ubyte.gz", 0x10).reshape(-1, 1, 28, 28), _fetch_mnist("train-labels-idx1-ubyte.gz", 8), \
+         _fetch_mnist("t10k-images-idx3-ubyte.gz", 0x10).reshape(-1, 1, 28, 28), _fetch_mnist("t10k-labels-idx1-ubyte.gz", 8)
````
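The new `nn.datasets.mnist` helper returns the four MNIST splits as tensors built from the raw IDX bytes. A minimal usage sketch (the first call needs network access, since `fetch` downloads the files):

```python
from tinygrad.nn.datasets import mnist

X_train, Y_train, X_test, Y_test = mnist()
print(X_train.shape, Y_train.shape)  # (60000, 1, 28, 28) (60000,)
print(X_test.shape, Y_test.shape)    # (10000, 1, 28, 28) (10000,)
```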
tinygrad/nn/optim.py
CHANGED
````diff
@@ -1,9 +1,13 @@
 # sorted in order of increasing complexity
 from typing import List
-from tinygrad.helpers import dedup, getenv
+from tinygrad.helpers import dedup, flatten, getenv
 from tinygrad.tensor import Tensor
+from tinygrad.dtype import dtypes, least_upper_dtype
 
 class Optimizer:
+  """
+  Base class for all optimizers.
+  """
   def __init__(self, params: List[Tensor], lr: float):
     # if it's None, but being put into an optimizer, set it to True
     for x in params:
@@ -13,54 +17,128 @@ class Optimizer:
     assert len(self.params) != 0, "optimizer must have at least one param"
     self.device = self.params[0].device
     self.buffers: List[Tensor] = dedup([x for x in params if not x.requires_grad]) # buffers are still realized
-
+    # store lr in at least float32 precision
+    self.lr = Tensor(lr if getenv("CONST_LR") else [lr], requires_grad=False, device=self.device,
+                     dtype=least_upper_dtype(dtypes.default_float, dtypes.float32))
 
   def zero_grad(self):
+    """
+    Zeroes the gradients of all the parameters.
+    """
     for param in self.params: param.grad = None
 
-  def
-
-
+  def step(self):
+    """
+    Performs a single optimization step.
+    """
+    Tensor.realize(*self.schedule_step())
+  def schedule_step(self) -> List[Tensor]:
+    """
+    Returns the tensors that need to be realized to perform a single optimization step.
+    """
+    assert Tensor.training, (
+            f"""Tensor.training={Tensor.training}, Tensor.training must be enabled to use the optimizer.
+            - help: Consider setting Tensor.training=True before calling Optimizer.step().""")
+    return self._step()+self.params+self.buffers
+  def _step(self) -> List[Tensor]: raise NotImplementedError
 
-class
-
+class OptimizerGroup(Optimizer):
+  """
+  Combines multiple optimizers into one.
+  """
+  def __init__(self, *optimizers: Optimizer): # pylint: disable=super-init-not-called
+    self.optimizers = optimizers
+    self.params, self.buffers = flatten([o.params for o in self.optimizers]), flatten([o.buffers for o in self.optimizers])
+  def __getitem__(self, i): return self.optimizers[i]
+  def zero_grad(self): [o.zero_grad() for o in self.optimizers]
+  def _step(self) -> List[Tensor]: return [x for o in self.optimizers for x in o._step()]
+
+# LARS is essentially just trust ratio to SGD so if we just set the trust coeff 0.0 its just standard SGD.
+def SGD(params: List[Tensor], lr=0.001, momentum=0.0, weight_decay=0.0, nesterov=False, classic=False):
+  """
+  Stochastic Gradient Descent (SGD) optimizer with optional momentum and weight decay.
+
+  `classic` is a boolean flag that determines whether to use the popular momentum update rule or the classic momentum update rule.
+
+  - Described: https://paperswithcode.com/method/sgd
+  """
+  return LARS(params, lr, momentum, weight_decay, nesterov, classic, tcoef=0.0)
+
+class LARS(Optimizer):
+  """
+  Layer-wise Adaptive Rate Scaling (LARS) optimizer with optional momentum and weight decay.
+
+  - Described: https://paperswithcode.com/method/lars
+  - Paper: https://arxiv.org/abs/1708.03888v3
+  """
+  def __init__(self, params:List[Tensor], lr=0.001, momentum=0.9, weight_decay=1e-4, nesterov=False, classic=True, tcoef=0.001):
     super().__init__(params, lr)
-    self.momentum, self.wd, self.nesterov = momentum, weight_decay, nesterov
-    self.b = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params] if self.momentum else []
+    self.momentum, self.wd, self.nesterov, self.classic, self.tcoef = momentum, weight_decay, nesterov, classic, tcoef
+    self.b = [Tensor.zeros(*t.shape, dtype=t.dtype, device=t.device, requires_grad=False) for t in self.params] if self.momentum else []
 
-
-  def step(self) -> None:
+  def _step(self) -> List[Tensor]:
     for i, t in enumerate(self.params):
       assert t.grad is not None
-      #
+      # contiguous is needed since the grads can allegedly form a "diamond"
       # TODO: fix this in lazy.py
-      t.grad.
-
+      g = t.grad.contiguous()
+      if self.tcoef != 0:
+        r1 = t.detach().square().sum().sqrt()
+        r2 = g.square().sum().sqrt()
+        r = (r1 > 0).where((r2 > 0).where(self.tcoef * r1 / (r2 + self.wd * r1), 1.0), 1.0)
+      else: r = 1.0
+      g = g + self.wd * t.detach()
+      # classic momentum does post learning rate update
+      if self.classic: g = g * r * self.lr
       if self.momentum:
         self.b[i].assign(self.momentum * self.b[i] + g) # NOTE: self.b[i] is zero on the first run, no if required
         g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
-
-
+      # popular momentum does pre learning rate update
+      if not self.classic: g = g * r * self.lr
+      t.assign((t.detach() - g).cast(t.dtype))
+    return self.b
 
 # LAMB is essentially just the trust ratio part of LARS applied to Adam/W so if we just set the trust ratio to 1.0 its just Adam/W.
-def AdamW(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8,
-
+def AdamW(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.01):
+  """
+  AdamW optimizer with optional weight decay.
+
+  - Described: https://paperswithcode.com/method/adamw
+  - Paper: https://arxiv.org/abs/1711.05101v3
+  """
+  return LAMB(params, lr, b1, b2, eps, weight_decay, adam=True)
+def Adam(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
+  """
+  Adam optimizer.
+
+  - Described: https://paperswithcode.com/method/adam
+  - Paper: https://arxiv.org/abs/1412.6980
+  """
+  return LAMB(params, lr, b1, b2, eps, 0.0, adam=True)
 
 class LAMB(Optimizer):
-
+  """
+  LAMB optimizer with optional weight decay.
+
+  - Described: https://paperswithcode.com/method/lamb
+  - Paper: https://arxiv.org/abs/1904.00962
+  """
+  def __init__(self, params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, adam=False):
     super().__init__(params, lr)
-    self.b1, self.b2, self.eps, self.wd, self.adam
-    self.
-    self.
+    self.b1, self.b2, self.eps, self.wd, self.adam = b1, b2, eps, weight_decay, adam
+    self.b1_t, self.b2_t = (Tensor([1], dtype=dtypes.float32, device=self.device, requires_grad=False).realize() for _ in [b1, b2])
+    self.m = [Tensor.zeros(*t.shape, dtype=dtypes.float32, device=t.device, requires_grad=False).contiguous() for t in self.params]
+    self.v = [Tensor.zeros(*t.shape, dtype=dtypes.float32, device=t.device, requires_grad=False).contiguous() for t in self.params]
 
-  def
-    self.
+  def _step(self) -> List[Tensor]:
+    self.b1_t *= self.b1
+    self.b2_t *= self.b2
     for i, t in enumerate(self.params):
       assert t.grad is not None
       self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * t.grad)
       self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad))
-      m_hat = self.m[i] / (1.0 - self.
-      v_hat = self.v[i] / (1.0 - self.
+      m_hat = self.m[i] / (1.0 - self.b1_t)
+      v_hat = self.v[i] / (1.0 - self.b2_t)
       up = (m_hat / (v_hat.sqrt() + self.eps)) + self.wd * t.detach()
       if not self.adam:
         r1 = t.detach().square().sum().sqrt()
@@ -68,5 +146,5 @@ class LAMB(Optimizer):
         r = Tensor.where(r1 > 0, Tensor.where(r2 > 0, r1 / r2, 1.0), 1.0)
       else:
         r = 1.0
-      t.assign(t.detach() - self.lr * r * up)
-      self.
+      t.assign((t.detach() - self.lr * r * up).cast(t.dtype))
+    return [self.b1_t, self.b2_t] + self.m + self.v
````
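The `Optimizer` base class now splits `step()` into `schedule_step()`, which returns the tensors to realize and asserts `Tensor.training`, followed by a single `Tensor.realize(...)` call, and `OptimizerGroup` lets several optimizers share one `step()`/`zero_grad()`. A small usage sketch under those assumptions (the parameter shapes and the loss are made up for illustration):

```python
from tinygrad import Tensor
from tinygrad.nn.optim import AdamW, SGD, OptimizerGroup

w1 = Tensor.randn(4, 4, requires_grad=True)
w2 = Tensor.randn(4, requires_grad=True)
# one optimizer per parameter group, combined so step()/zero_grad() hit both
opt = OptimizerGroup(AdamW([w1], lr=1e-3), SGD([w2], lr=1e-2, momentum=0.9))

with Tensor.train():                      # satisfies the Tensor.training assert in schedule_step()
  loss = ((w1.sum(axis=1) + w2) ** 2).sum()
  opt.zero_grad()
  loss.backward()
  opt.step()                              # realizes everything returned by schedule_step()
```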
tinygrad/nn/state.py
CHANGED
````diff
@@ -1,32 +1,49 @@
 import os, json, pathlib, zipfile, pickle, tarfile, struct
-from tqdm import tqdm
 from typing import Dict, Union, List, Optional, Any, Tuple
 from tinygrad.tensor import Tensor
-from tinygrad.ops import GlobalCounters
 from tinygrad.dtype import dtypes
-from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap
+from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm
 from tinygrad.shape.view import strides_for_shape
+from tinygrad.multi import MultiLazyBuffer
 
-safe_dtypes = {"
-               "
+safe_dtypes = {"BOOL":dtypes.bool, "I8":dtypes.int8, "U8":dtypes.uint8, "I16":dtypes.int16, "U16":dtypes.uint16, "I32":dtypes.int, "U32":dtypes.uint,
+               "I64":dtypes.int64, "U64":dtypes.uint64, "F16":dtypes.float16, "BF16":dtypes.bfloat16, "F32":dtypes.float32, "F64":dtypes.float64}
 inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}
 
 def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]:
+  """
+  Loads a .safetensor file from disk, returning the data, metadata length, and metadata.
+  """
   t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
-  json_len = t[0:
-  return
+  json_len = t[0:8].bitcast(dtypes.int64).item()
+  return t, json_len, json.loads(t[8:8+json_len].numpy().tobytes())
 
 def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
+  """
+  Loads a .safetensor file from disk, returning the state_dict.
+
+  ```python
+  state_dict = nn.state.safe_load("test.safetensor")
+  ```
+  """
   t, json_len, metadata = safe_load_metadata(fn)
   ret = {}
   for k,v in metadata.items():
     if k == "__metadata__": continue
     dtype = safe_dtypes[v['dtype']]
-    sz = (v['data_offsets'][1]-v['data_offsets'][0])
-    ret[k] = t[8+json_len+v['data_offsets'][0]:8+json_len+v['data_offsets'][0]+sz].
+    sz = (v['data_offsets'][1]-v['data_offsets'][0])
+    ret[k] = t[8+json_len+v['data_offsets'][0]:8+json_len+v['data_offsets'][0]+sz].bitcast(dtype).reshape(v['shape'])
   return ret
 
 def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any]]=None):
+  """
+  Saves a state_dict to disk in a .safetensor file with optional metadata.
+
+  ```python
+  t = Tensor([1, 2, 3])
+  nn.state.safe_save({'t':t}, "test.safetensor")
+  ```
+  """
   headers, offset = {}, 0
   if metadata: headers['__metadata__'] = metadata
   for k,v in tensors.items():
@@ -36,14 +53,27 @@ def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any
   j += "\x20"*((8-len(j)%8)%8)
   pathlib.Path(fn).unlink(missing_ok=True)
   t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}")
-  t[0:
-  t[8:8+len(j)].assign(
+  t[0:8].bitcast(dtypes.int64).assign([len(j)])
+  t[8:8+len(j)].assign(list(j.encode('utf-8')))
   for k,v in safe_load(t).items(): v.assign(tensors[k])
 
 # state dict
 
 from collections import OrderedDict
 def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> Dict[str, Tensor]:
+  """
+  Returns a state_dict of the object, with optional prefix.
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  class Net:
+    def __init__(self):
+      self.l1 = nn.Linear(4, 5)
+      self.l2 = nn.Linear(5, 6)
+
+  net = Net()
+  print(nn.state.get_state_dict(net).keys())
+  ```
+  """
   if isinstance(obj, tensor_type): return {prefix.strip('.'):obj}
   if hasattr(obj, '_asdict'): return get_state_dict(obj._asdict(), prefix, tensor_type) # namedtuple
   if isinstance(obj, OrderedDict): return get_state_dict(dict(obj), prefix, tensor_type)
@@ -54,24 +84,61 @@ def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> Dict[str, Tensor]:
   elif isinstance(obj, dict):
     for k,v in obj.items(): state_dict.update(get_state_dict(v, f"{prefix}{str(k)}.", tensor_type))
   return state_dict
-def get_parameters(obj) -> List[Tensor]:
+def get_parameters(obj) -> List[Tensor]:
+  """
+  ```python exec="true" source="above" session="tensor" result="python"
+  class Net:
+    def __init__(self):
+      self.l1 = nn.Linear(4, 5)
+      self.l2 = nn.Linear(5, 6)
 
-
+  net = Net()
+  print(len(nn.state.get_parameters(net)))
+  ```
+  """
+  return list(get_state_dict(obj).values())
+
+def load_state_dict(model, state_dict:Dict[str, Tensor], strict=True, verbose=True, consume=False) -> None:
+  """
+  Loads a state_dict into a model.
+
+  ```python
+  class Net:
+    def __init__(self):
+      self.l1 = nn.Linear(4, 5)
+      self.l2 = nn.Linear(5, 6)
+
+  net = Net()
+  state_dict = nn.state.get_state_dict(net)
+  nn.state.load_state_dict(net, state_dict)
+  ```
+  """
   start_mem_used = GlobalCounters.mem_used
   with Timing("loaded weights in ", lambda et_ns: f", {(GlobalCounters.mem_used-start_mem_used)/1e9:.2f} GB loaded at {(GlobalCounters.mem_used-start_mem_used)/et_ns:.2f} GB/s"): # noqa: E501
     model_state_dict = get_state_dict(model)
     if DEBUG >= 1 and len(state_dict) > len(model_state_dict):
       print("WARNING: unused weights in state_dict", sorted(list(state_dict.keys() - model_state_dict.keys())))
     for k,v in (t := tqdm(model_state_dict.items(), disable=CI or not verbose)):
-      t.
+      t.desc = f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, {k:50s}: "
      if k not in state_dict and not strict:
        if DEBUG >= 1: print(f"WARNING: not loading {k}")
        continue
-
+      if isinstance((mlb:=v.lazydata), MultiLazyBuffer):
+        if isinstance(state_dict[k].lazydata, MultiLazyBuffer): v.replace(state_dict[k]).realize()
+        else: v.replace(state_dict[k].shard(mlb.device, mlb.axis)).realize()
+      else: v.replace(state_dict[k].to(v.device)).realize()
+      if consume: del state_dict[k]
 
 # torch support!
 
 def torch_load(fn:str) -> Dict[str, Tensor]:
+  """
+  Loads a torch .pth file from disk.
+
+  ```python
+  state_dict = nn.state.torch_load("test.pth")
+  ```
+  """
   t = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
 
   offsets: Dict[Union[str, int], int] = {}
@@ -81,7 +148,7 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
     lens[storage[2]] = storage[4] * storage[1].itemsize
     if storage[2] not in offsets: return None
     byte_offset = offsets[storage[2]]+storage_offset*storage[1].itemsize
-    ret = t[byte_offset:byte_offset+prod(size)].
+    ret = t[byte_offset:byte_offset+prod(size)*storage[1].itemsize].bitcast(storage[1])
 
     # 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
     shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]
@@ -89,10 +156,11 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
     if tuple(permute_indexes) != tuple(range(len(permute_indexes))):
       intermediate_shape = tuple([shape_strides[x][0] for x in argsort(permute_indexes)])
       assert tuple([shape_strides[i][1] for i in argsort(permute_indexes)]) == strides_for_shape(intermediate_shape), "nonpermutable strides"
-      if DEBUG >= 3: print(f"WARNING: this torch load is slow.
-      assert storage[1] != dtypes.bfloat16, "can't
+      if DEBUG >= 3: print(f"WARNING: this torch load is slow. CLANG to permute {intermediate_shape} with {permute_indexes}")
+      assert storage[1] != dtypes.bfloat16, "can't CLANG permute BF16"
       # TODO: find a nice way to support all shapetracker on disktensors
-
+      # TODO: BUG: a ".realize()" is needed here for 'GPU=1 python3 test/models/test_efficientnet.py TestEfficientNet.test_car'
+      ret = ret.clang().reshape(intermediate_shape).permute(permute_indexes).realize()
 
     return ret.reshape(size)
 
````
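Putting the new `nn.state` helpers together, a minimal round-trip sketch based on the docstring examples in the diff above (the file path is illustrative only):

```python
from tinygrad import Tensor, nn

t = Tensor([1, 2, 3])
nn.state.safe_save({"t": t}, "/tmp/test.safetensor")   # writes the 8-byte header length, JSON metadata, then raw data
loaded = nn.state.safe_load("/tmp/test.safetensor")    # dict of disk-backed Tensors, keyed by name
print(loaded["t"].numpy())                             # [1 2 3]
```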