tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +11 -6
- tinygrad/codegen/kernel.py +308 -175
- tinygrad/codegen/linearize.py +95 -0
- tinygrad/codegen/lowerer.py +143 -0
- tinygrad/codegen/transcendental.py +257 -0
- tinygrad/codegen/uopgraph.py +506 -0
- tinygrad/device.py +72 -171
- tinygrad/dtype.py +122 -47
- tinygrad/engine/jit.py +184 -87
- tinygrad/{lazy.py → engine/lazy.py} +74 -66
- tinygrad/engine/memory.py +51 -0
- tinygrad/engine/realize.py +86 -61
- tinygrad/engine/schedule.py +366 -317
- tinygrad/engine/search.py +58 -47
- tinygrad/function.py +59 -58
- tinygrad/helpers.py +120 -102
- tinygrad/multi.py +82 -78
- tinygrad/nn/__init__.py +116 -67
- tinygrad/nn/datasets.py +12 -5
- tinygrad/nn/optim.py +1 -1
- tinygrad/nn/state.py +91 -6
- tinygrad/ops.py +1126 -143
- tinygrad/renderer/__init__.py +47 -23
- tinygrad/renderer/cstyle.py +338 -265
- tinygrad/renderer/llvmir.py +125 -143
- tinygrad/renderer/ptx.py +225 -0
- tinygrad/runtime/autogen/adreno.py +17904 -0
- tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
- tinygrad/runtime/autogen/cuda.py +6 -162
- tinygrad/runtime/autogen/io_uring.py +97 -63
- tinygrad/runtime/autogen/kfd.py +60 -47
- tinygrad/runtime/autogen/kgsl.py +1386 -0
- tinygrad/runtime/autogen/libc.py +5462 -0
- tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
- tinygrad/runtime/autogen/nvrtc.py +579 -0
- tinygrad/runtime/autogen/opencl.py +11 -11
- tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
- tinygrad/runtime/graph/clang.py +3 -3
- tinygrad/runtime/graph/cuda.py +11 -15
- tinygrad/runtime/graph/hcq.py +120 -107
- tinygrad/runtime/graph/metal.py +71 -43
- tinygrad/runtime/ops_amd.py +244 -323
- tinygrad/runtime/ops_clang.py +12 -5
- tinygrad/runtime/ops_cloud.py +220 -0
- tinygrad/runtime/ops_cuda.py +42 -99
- tinygrad/runtime/ops_disk.py +25 -26
- tinygrad/runtime/ops_dsp.py +181 -0
- tinygrad/runtime/ops_gpu.py +29 -16
- tinygrad/runtime/ops_hip.py +68 -0
- tinygrad/runtime/ops_llvm.py +15 -10
- tinygrad/runtime/ops_metal.py +147 -64
- tinygrad/runtime/ops_nv.py +356 -397
- tinygrad/runtime/ops_python.py +78 -79
- tinygrad/runtime/ops_qcom.py +405 -0
- tinygrad/runtime/support/__init__.py +0 -0
- tinygrad/runtime/support/compiler_cuda.py +77 -0
- tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
- tinygrad/runtime/support/elf.py +38 -0
- tinygrad/runtime/support/hcq.py +539 -0
- tinygrad/shape/shapetracker.py +40 -50
- tinygrad/shape/view.py +102 -63
- tinygrad/tensor.py +1109 -365
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
- tinygrad-0.10.0.dist-info/RECORD +77 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/linearizer.py +0 -528
- tinygrad/codegen/uops.py +0 -451
- tinygrad/engine/graph.py +0 -100
- tinygrad/renderer/assembly.py +0 -269
- tinygrad/shape/symbolic.py +0 -327
- tinygrad-0.9.1.dist-info/RECORD +0 -63
- /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
tinygrad/nn/__init__.py
CHANGED
@@ -1,12 +1,14 @@
+from __future__ import annotations
 import math
-from typing import Optional, Union, Tuple,
-from tinygrad.tensor import Tensor
-from tinygrad.
+from typing import Optional, Union, Tuple, List
+from tinygrad.tensor import Tensor, dtypes
+from tinygrad.device import is_dtype_supported
+from tinygrad.helpers import prod, make_tuple, flatten
 from tinygrad.nn import optim, state, datasets  # noqa: F401
 
-class
+class BatchNorm:
   """
-  Applies Batch Normalization over a
+  Applies Batch Normalization over a 2D or 3D input.
 
   - Described: https://paperswithcode.com/method/batch-normalization
   - Paper: https://arxiv.org/abs/1502.03167v3
@@ -20,7 +22,7 @@ class BatchNorm2d:
   ```
 
   ```python exec="true" source="above" session="tensor" result="python"
-  norm = nn.
+  norm = nn.BatchNorm(3)
   t = Tensor.rand(2, 3, 4, 4)
   print(t.mean().item(), t.std().item())
   ```
@@ -32,36 +34,34 @@ class BatchNorm2d:
   def __init__(self, sz:int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
     self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum
 
-
-
-
-    self.
-    self.
-
-  def
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-# TODO: these Conv lines are terrible
-def Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
+    self.weight: Optional[Tensor] = Tensor.ones(sz) if affine else None
+    self.bias: Optional[Tensor] = Tensor.zeros(sz) if affine else None
+
+    self.num_batches_tracked = Tensor.zeros(1, dtype='long' if is_dtype_supported(dtypes.long) else 'int', requires_grad=False)
+    if track_running_stats: self.running_mean, self.running_var = Tensor.zeros(sz, requires_grad=False), Tensor.ones(sz, requires_grad=False)
+
+  def calc_stats(self, x:Tensor) -> Tuple[Tensor, Tensor]:
+    shape_mask: List[int] = [1, -1, *([1]*(x.ndim-2))]
+    if self.track_running_stats and not Tensor.training: return self.running_mean, self.running_var.reshape(shape=shape_mask).expand(x.shape)
+    # This requires two full memory accesses to x
+    # https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
+    # There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
+    batch_mean = x.mean(axis=(reduce_axes:=tuple(x for x in range(x.ndim) if x != 1)))
+    y = (x - batch_mean.detach().reshape(shape=shape_mask))  # d(var)/d(mean) = 0
+    batch_var = (y*y).mean(axis=reduce_axes)
+    return batch_mean, batch_var
+
+  def __call__(self, x:Tensor) -> Tensor:
+    batch_mean, batch_var = self.calc_stats(x)
+    # NOTE: wow, this is done all throughout training in most PyTorch models
+    if self.track_running_stats and Tensor.training:
+      self.running_mean.assign((1-self.momentum) * self.running_mean + self.momentum * batch_mean.detach())
+      self.running_var.assign((1-self.momentum) * self.running_var + self.momentum * prod(x.shape)/(prod(x.shape)-x.shape[1]) * batch_var.detach())
+      self.num_batches_tracked += 1
+    return x.batchnorm(self.weight, self.bias, batch_mean, batch_var.add(self.eps).rsqrt())
+BatchNorm2d = BatchNorm3d = BatchNorm
+
+def Conv1d(in_channels:int, out_channels:int, kernel_size:int, stride=1, padding:Union[int, str]=0, dilation=1, groups=1, bias=True) -> Conv2d:
   """
   Applies a 1D convolution over an input signal composed of several input planes.
 
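
Note: the hunk above replaces the old `BatchNorm2d` with a dimension-agnostic `BatchNorm` that reduces over every axis except the channel axis (1) and aliases `BatchNorm2d = BatchNorm3d = BatchNorm`. A minimal usage sketch (assuming tinygrad 0.10.0; shapes are illustrative):

```python
from tinygrad import Tensor, nn

Tensor.training = True                # running stats are only updated in training mode
bn = nn.BatchNorm(8)                  # nn.BatchNorm2d and nn.BatchNorm3d point to the same class
x2d = Tensor.rand(4, 8, 16, 16)       # (N, C, H, W)
x3d = Tensor.rand(4, 8, 2, 16, 16)    # (N, C, D, H, W)
print(bn(x2d).shape, bn(x3d).shape)   # per-channel stats, shared across all trailing dims
```
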
@@ -95,20 +95,24 @@ class Conv2d:
   print(t.numpy())
   ```
   """
-  def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding
-
-    self.
-
-
-
-
-
+  def __init__(self, in_channels:int, out_channels:int, kernel_size:Union[int, Tuple[int, ...]], stride=1, padding:Union[int, Tuple[int, ...], str]=0,
+               dilation=1, groups=1, bias=True):
+    self.kernel_size = make_tuple(kernel_size, 2)
+    if isinstance(padding, str):
+      if padding.lower() != 'same': raise ValueError(f"Invalid padding string {padding!r}, only 'same' is supported")
+      if stride != 1: raise ValueError("padding='same' is not supported for strided convolutions")
+      pad = [(d*(k-1)//2, d*(k-1) - d*(k-1)//2) for d,k in zip(make_tuple(dilation, len(self.kernel_size)), self.kernel_size[::-1])]
+      padding = tuple(flatten(pad))
+    self.stride, self.dilation, self.groups, self.padding = stride, dilation, groups, padding
+    scale = 1 / math.sqrt(in_channels * prod(self.kernel_size))
+    self.weight = Tensor.uniform(out_channels, in_channels//groups, *self.kernel_size, low=-scale, high=scale)
+    self.bias: Optional[Tensor] = Tensor.uniform(out_channels, low=-scale, high=scale) if bias else None
+
+  def __call__(self, x:Tensor) -> Tensor:
     return x.conv2d(self.weight, self.bias, padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
 
-
-
-
-def ConvTranspose1d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
+def ConvTranspose1d(in_channels:int, out_channels:int, kernel_size:int, stride=1, padding=0, output_padding=0, dilation=1,
+                    groups=1, bias=True) -> ConvTranspose2d:
   """
   Applies a 1D transposed convolution operator over an input signal composed of several input planes.
 
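
Note: `Conv2d.__init__` now accepts `padding='same'` (stride must be 1), adding a total of `d*(k-1)` padding per spatial dim, split across both sides. A hedged sketch of the new behavior (assuming tinygrad 0.10.0):

```python
from tinygrad import Tensor, nn

conv = nn.Conv2d(3, 16, kernel_size=5, padding='same')   # pads 2+2 per spatial dim for k=5, d=1
x = Tensor.rand(1, 3, 32, 32)
print(conv(x).shape)                                     # (1, 16, 32, 32): spatial size preserved
```
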
@@ -142,17 +146,17 @@ class ConvTranspose2d(Conv2d):
   print(t.numpy())
   ```
   """
-  def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0,
+  def __init__(self, in_channels:int, out_channels:int, kernel_size:Union[int, Tuple[int, ...]], stride=1, padding=0, output_padding=0,
+               dilation=1, groups=1, bias=True):
     super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
+    scale = 1 / math.sqrt(in_channels * prod(self.kernel_size))
+    self.weight = Tensor.uniform(in_channels, out_channels//groups, *self.kernel_size, low=-scale, high=scale)
     self.output_padding = output_padding
 
-  def __call__(self, x:Tensor):
+  def __call__(self, x:Tensor) -> Tensor:
     return x.conv_transpose2d(self.weight, self.bias, padding=self.padding, output_padding=self.output_padding, stride=self.stride,
                               dilation=self.dilation, groups=self.groups)
 
-  def initialize_weight(self, out_channels, in_channels, groups):
-    return Tensor.kaiming_uniform(in_channels, out_channels//groups, *self.kernel_size, a=math.sqrt(5))
-
 class Linear:
   """
   Applies a linear transformation to the incoming data.
@@ -169,13 +173,12 @@ class Linear:
   print(t.numpy())
   ```
   """
-  def __init__(self, in_features, out_features, bias=True):
-    # TODO: is this init good? torch inits to uniform(-1/sqrt(in_features), 1/sqrt(in_features))
-    self.weight = Tensor.kaiming_uniform(out_features, in_features, a=math.sqrt(5))
+  def __init__(self, in_features:int, out_features:int, bias=True):
     bound = 1 / math.sqrt(in_features)
+    self.weight = Tensor.uniform(out_features, in_features, low=-bound, high=bound)
     self.bias = Tensor.uniform(out_features, low=-bound, high=bound) if bias else None
 
-  def __call__(self, x:Tensor):
+  def __call__(self, x:Tensor) -> Tensor:
     return x.linear(self.weight.transpose(), self.bias)
 
 class GroupNorm:
@@ -195,12 +198,12 @@ class GroupNorm:
   print(t.mean().item(), t.std().item())
   ```
   """
-  def __init__(self, num_groups:int, num_channels:int, eps
+  def __init__(self, num_groups:int, num_channels:int, eps=1e-5, affine=True):
     self.num_groups, self.num_channels, self.eps = num_groups, num_channels, eps
     self.weight: Optional[Tensor] = Tensor.ones(num_channels) if affine else None
     self.bias: Optional[Tensor] = Tensor.zeros(num_channels) if affine else None
 
-  def __call__(self, x:Tensor):
+  def __call__(self, x:Tensor) -> Tensor:
     # reshape for layernorm to work as group norm
     # subtract mean and divide stddev
     x = x.reshape(x.shape[0], self.num_groups, -1).layernorm(eps=self.eps).reshape(x.shape)
@@ -226,12 +229,12 @@ class InstanceNorm:
   print(t.mean().item(), t.std().item())
   ```
   """
-  def __init__(self, num_features:int, eps
+  def __init__(self, num_features:int, eps=1e-5, affine=True):
     self.num_features, self.eps = num_features, eps
     self.weight: Optional[Tensor] = Tensor.ones(num_features) if affine else None
     self.bias: Optional[Tensor] = Tensor.zeros(num_features) if affine else None
 
-  def __call__(self, x:Tensor):
+  def __call__(self, x:Tensor) -> Tensor:
     x = x.reshape(x.shape[0], self.num_features, -1).layernorm(eps=self.eps).reshape(x.shape)
     if self.weight is None or self.bias is None: return x
     return x * self.weight.reshape(1, -1, *[1] * (len(x.shape)-2)) + self.bias.reshape(1, -1, *[1] * (len(x.shape)-2))
@@ -253,12 +256,12 @@ class LayerNorm:
   print(t.mean().item(), t.std().item())
   ```
   """
-  def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps
-    self.normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
+  def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps=1e-5, elementwise_affine=True):
+    self.normalized_shape: Tuple[int, ...] = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
     self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(self.normalized_shape))), eps, elementwise_affine
     self.weight, self.bias = (Tensor.ones(*self.normalized_shape), Tensor.zeros(*self.normalized_shape)) if elementwise_affine else (None, None)
 
-  def __call__(self, x:Tensor):
+  def __call__(self, x:Tensor) -> Tensor:
     assert self.normalized_shape == x.shape[-len(self.normalized_shape):], f"last dimensions of {x.shape} must match {self.normalized_shape}"
     x = x.layernorm(eps=self.eps, axis=self.axis)
     if not self.elementwise_affine: return x
@@ -280,7 +283,29 @@ class LayerNorm2d(LayerNorm):
   print(t.mean().item(), t.std().item())
   ```
   """
-  def __call__(self, x): return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+  def __call__(self, x: Tensor) -> Tensor: return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+
+class RMSNorm:
+  """
+  Applies Root Mean Square Normalization to input.
+
+  - Described: https://paperswithcode.com/method/rmsnorm
+  - Paper: https://arxiv.org/abs/1910.07467
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  norm = nn.RMSNorm(4)
+  t = Tensor.arange(12, dtype=dtypes.float).reshape(3, 4)
+  print(t.numpy())
+  ```
+  ```python exec="true" source="above" session="tensor" result="python"
+  print(norm(t).numpy())
+  ```
+  """
+  def __init__(self, dim:int, eps=1e-6): self.eps, self.weight = eps, Tensor.ones(dim)
+
+  def _norm(self, x:Tensor) -> Tensor: return x * (x.square().mean(-1, keepdim=True) + self.eps).rsqrt()
+
+  def __call__(self, x:Tensor) -> Tensor: return self._norm(x.float()).cast(x.dtype) * self.weight
 
 class Embedding:
   """
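
Note: the new `RMSNorm` takes its statistic in float32 and rescales by a learnable weight, i.e. `y = x / sqrt(mean(x**2, axis=-1) + eps) * weight`. A small sketch checking it against that formula (assuming tinygrad 0.10.0):

```python
from tinygrad import Tensor, dtypes, nn

norm = nn.RMSNorm(4, eps=1e-6)
x = Tensor.arange(12, dtype=dtypes.float).reshape(3, 4)
manual = x * (x.square().mean(axis=-1, keepdim=True) + 1e-6).rsqrt() * norm.weight
print((norm(x) - manual).abs().max().item())   # ~0.0
```
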
@@ -298,7 +323,31 @@ class Embedding:
 
   def __call__(self, idx:Tensor) -> Tensor:
     if idx.numel() == 0: return Tensor.empty(idx.shape+(self.embed_sz,), device=self.weight.device)
-    arange_shp, weight_shp, big_shp = (
+    arange_shp, weight_shp, big_shp = (self.vocab_sz, 1), (self.vocab_sz, self.embed_sz), idx.shape+(self.vocab_sz, self.embed_sz,)
     if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).reshape(arange_shp)
     arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1,)).expand(big_shp), self.weight.reshape(weight_shp).expand(big_shp)
-    return (arange == idx).mul(vals).sum(2)
+    return (arange == idx).mul(vals).sum(-2, acc_dtype=vals.dtype)
+
+class LSTMCell:
+  """
+  A long short-term memory (LSTM) cell.
+
+  Args:
+    input_size: The number of expected features in the input `x`
+    hidden_size: The number of features in the hidden state `h`
+    bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`
+  """
+  def __init__(self, input_size:int, hidden_size:int, bias:bool=True):
+    stdv = 1.0 / math.sqrt(hidden_size)
+    self.weight_ih = Tensor.uniform(hidden_size*4, input_size, low=-stdv, high=stdv)
+    self.weight_hh = Tensor.uniform(hidden_size*4, hidden_size, low=-stdv, high=stdv)
+    self.bias_ih, self.bias_hh = (Tensor.zeros(hidden_size*4), Tensor.zeros(hidden_size*4)) if bias else (None, None)
+
+  def __call__(self, x:Tensor, hc:Optional[Tuple[Tensor, Tensor]]=None) -> Tuple[Tensor, Tensor]:
+    if hc is None: hc = (Tensor.zeros(x.size(0), self.weight_hh.size(1), dtype=x.dtype, device=x.device),)*2
+    gates = x.linear(self.weight_ih.T, self.bias_ih) + hc[0].linear(self.weight_hh.T, self.bias_hh)
+    i, f, g, o = gates.chunk(4, dim=1)
+    i, f, g, o = i.sigmoid(), f.sigmoid(), g.tanh(), o.sigmoid()
+    new_c = f * hc[1] + i * g
+    new_h = o * new_c.tanh()
+    return (new_h.contiguous(), new_c.contiguous())
tinygrad/nn/datasets.py
CHANGED
@@ -1,8 +1,15 @@
-import gzip
 from tinygrad.tensor import Tensor
 from tinygrad.helpers import fetch
+from tinygrad.nn.state import tar_extract
 
-def
-
-
-
+def mnist(device=None, fashion=False):
+  base_url = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/" if fashion else "https://storage.googleapis.com/cvdf-datasets/mnist/"
+  def _mnist(file): return Tensor.from_url(base_url+file, gunzip=True)
+  return _mnist("train-images-idx3-ubyte.gz")[0x10:].reshape(-1,1,28,28).to(device), _mnist("train-labels-idx1-ubyte.gz")[8:].to(device), \
+         _mnist("t10k-images-idx3-ubyte.gz")[0x10:].reshape(-1,1,28,28).to(device), _mnist("t10k-labels-idx1-ubyte.gz")[8:].to(device)
+
+def cifar(device=None):
+  tt = tar_extract(fetch('https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz', gunzip=True))
+  train = Tensor.cat(*[tt[f"cifar-10-batches-bin/data_batch_{i}.bin"].reshape(-1, 3073).to(device) for i in range(1,6)])
+  test = tt["cifar-10-batches-bin/test_batch.bin"].reshape(-1, 3073).to(device)
+  return train[:, 1:].reshape(-1,3,32,32), train[:, 0], test[:, 1:].reshape(-1,3,32,32), test[:, 0]
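
Note: `mnist()` now streams directly via `Tensor.from_url` (with a `fashion=True` variant) and `cifar()` reads the CIFAR-10 binary tarball through the new `tar_extract`. A usage sketch (assuming tinygrad 0.10.0 and network access):

```python
from tinygrad import nn

X_train, Y_train, X_test, Y_test = nn.datasets.mnist()
print(X_train.shape, Y_train.shape)    # (60000, 1, 28, 28) (60000,)

Xc_train, Yc_train, Xc_test, Yc_test = nn.datasets.cifar()
print(Xc_train.shape, Yc_train.shape)  # (50000, 3, 32, 32) (50000,)
```
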
tinygrad/nn/optim.py
CHANGED
@@ -126,7 +126,7 @@ class LAMB(Optimizer):
   def __init__(self, params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, adam=False):
     super().__init__(params, lr)
     self.b1, self.b2, self.eps, self.wd, self.adam = b1, b2, eps, weight_decay, adam
-    self.b1_t, self.b2_t = (Tensor(
+    self.b1_t, self.b2_t = (Tensor.ones((1,), dtype=dtypes.float32, device=self.device, requires_grad=False).contiguous() for _ in [b1, b2])
     self.m = [Tensor.zeros(*t.shape, dtype=dtypes.float32, device=t.device, requires_grad=False).contiguous() for t in self.params]
     self.v = [Tensor.zeros(*t.shape, dtype=dtypes.float32, device=t.device, requires_grad=False).contiguous() for t in self.params]
 
tinygrad/nn/state.py
CHANGED
@@ -1,5 +1,5 @@
-import os, json, pathlib, zipfile, pickle, tarfile, struct
-from typing import Dict, Union, List, Optional, Any, Tuple
+import os, json, pathlib, zipfile, pickle, tarfile, struct, functools
+from typing import Dict, Union, List, Optional, Any, Tuple, Callable
 from tinygrad.tensor import Tensor
 from tinygrad.dtype import dtypes
 from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm
@@ -16,7 +16,7 @@ def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]:
   """
   t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
   json_len = t[0:8].bitcast(dtypes.int64).item()
-  return t, json_len, json.loads(t[8:8+json_len].
+  return t, json_len, json.loads(t[8:8+json_len].data().tobytes())
 
 def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
   """
@@ -129,6 +129,18 @@ def load_state_dict(model, state_dict:Dict[str, Tensor], strict=True, verbose=Tr
     else: v.replace(state_dict[k].to(v.device)).realize()
     if consume: del state_dict[k]
 
+def tar_extract(fn:os.PathLike) -> Dict[str, Tensor]:
+  """
+  Extracts files from a tar archive and returns them as dictionary of names (keys) and tensors (values).
+
+  ```python
+  tensors = nn.state.tar_extract("archive.tar")
+  ```
+  """
+  t = Tensor(pathlib.Path(fn))
+  with tarfile.open(fn, "r") as tar:
+    return {member.name:t[member.offset_data:member.offset_data+member.size] for member in tar if member.type == tarfile.REGTYPE}
+
 # torch support!
 
 def torch_load(fn:str) -> Dict[str, Tensor]:
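
Note: `tar_extract` maps each regular member of a tar archive to a slice of a single disk-backed byte `Tensor`, so nothing is copied until a slice is realized. A hedged sketch (assuming tinygrad 0.10.0; `archive.tar` is a placeholder for any uncompressed tar file):

```python
from tinygrad import nn

members = nn.state.tar_extract("archive.tar")   # {member name: byte Tensor slice}
for name, t in members.items():
  print(name, t.shape, t.dtype)                 # one entry per regular file in the archive
```
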
@@ -159,8 +171,7 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
       if DEBUG >= 3: print(f"WARNING: this torch load is slow. CLANG to permute {intermediate_shape} with {permute_indexes}")
       assert storage[1] != dtypes.bfloat16, "can't CLANG permute BF16"
       # TODO: find a nice way to support all shapetracker on disktensors
-
-      ret = ret.clang().reshape(intermediate_shape).permute(permute_indexes).realize()
+      ret = ret.to(None).reshape(intermediate_shape).permute(permute_indexes)
 
     return ret.reshape(size)
 
@@ -168,7 +179,8 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
     def __setstate__(self, state): self.tensor = state[0]
 
   deserialized_objects: Dict[str, Any] = {}
-  intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16,
+  intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16,
+               "IntStorage": dtypes.int32, "BoolStorage": dtypes.bool,
                "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2, "FloatTensor": None, "Parameter": Parameter}
   whitelist = {"torch", "collections", "numpy", "_codecs"}  # NOTE: this is not for security, only speed
   class Dummy: pass
@@ -214,3 +226,76 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
       base_offset += 8 + lens[i]
     f.seek(rwd)
     return TorchPickle(f).load()
+
+def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
+  """
+  Converts ggml tensor data to a tinygrad tensor.
+
+  Supported native types: float32 (id: 0), float16 (id: 1), int8 (id: 16), int16 (id: 17), int32 (id: 18)
+  Supported quantized types: Q4_0 (id: 2), Q4_1 (id: 3), Q8_0 (id: 8), Q6_K (id: 14)
+  """
+  # https://github.com/ggerganov/ggml/blob/6dccc647264f5429df2624f36138f601e7ce23e5/include/ggml.h#L356
+
+  # native types
+  if (dtype := { 0: dtypes.float32, 1: dtypes.float16, 16: dtypes.int8, 17: dtypes.int16, 18: dtypes.int32 }.get(ggml_type)) is not None:
+    return t[:dtype.itemsize * n].bitcast(dtype)
+
+  def q_to_uint8(t: Tensor, b: int) -> Tensor:
+    # TODO: rewrite with arange?
+    shift_tensor, bitmask = Tensor.stack(*[ Tensor(2**(i*b), device=t.device, dtype=t.dtype) for i in range(8//b) ]), 0xff >> (8 - b)
+    return t.unsqueeze(-1).expand((*t.shape,8//b)).idiv(shift_tensor).bitwise_and(bitmask).transpose(-1, -2).flatten(-2)
+
+  # map to (number of elements, number of bytes)
+  if (nelements_nbytes := { 2: (32, 18), 3: (32, 20), 14: (256, 210), 8: (32, 34) }.get(ggml_type)) is not None:
+    blocks = t[:(n//nelements_nbytes[0])*nelements_nbytes[1]].reshape((-1, nelements_nbytes[1]))
+    if ggml_type == 2: return (q_to_uint8(blocks[:,2:], 4).bitcast(dtypes.int8) - 8) * blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32)
+    if ggml_type == 3:
+      d, m = (blocks[:,s:s+2].bitcast(dtypes.float16).cast(dtypes.float32) for s in [ 0, 2 ])
+      return q_to_uint8(blocks[:,4:], 4).bitcast(dtypes.int8) * d + m
+    if ggml_type == 8: return blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32) * blocks[:,2:].bitcast(dtypes.int8)
+    if ggml_type == 14:
+      xl, xh = q_to_uint8(blocks[:,:128].reshape((-1, 2, 64)), 4), q_to_uint8(blocks[:,128:192].reshape((-1, 2, 32)), 2).lshift(4)
+      scales = blocks[:,192:208].bitcast(dtypes.int8).unsqueeze(-1).expand((-1, 16, 16)).reshape((-1, 256))
+      d = blocks[:,-2:].bitcast(dtypes.float16).cast(dtypes.float32).expand((-1, 256))
+      return d * (xl.bitwise_or(xh).bitcast(dtypes.int8) - 32).flatten(-2) * scales
+  raise ValueError(f"GGML type '{ggml_type}' is not supported!")
+
+def gguf_load(tensor: Tensor) -> Tuple[Dict, Dict[str, Tensor]]:
+  """
+  Loads a gguf file from a tensor.
+
+  ```python
+  fn = "Meta-Llama-3-8B-Instruct.Q4_0.gguf"
+  gguf_tensor = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}").to(Device.DEFAULT)
+  kv_data, state_dict = gguf_load(gguf_tensor)
+  ```
+  """
+  if tensor.dtype != dtypes.uint8 or len(tensor.shape) != 1: raise ValueError("GGUF tensor must be 1d and of dtype uint8!")
+  pos, read_buffer, rb_start, kv_data, state_dict = 0, memoryview(bytes()), 0, {}, {}
+  def read_bytes(n: int):
+    nonlocal pos, read_buffer, rb_start
+    if rb_start + len(read_buffer) < pos + n: rb_start, read_buffer = pos, tensor[pos:(pos+max(n, 1000_000))].data()
+    return read_buffer[pos-rb_start:(pos:=pos+n)-rb_start]
+  def read_unpack(fmt: str, n: int): return struct.unpack(fmt, read_bytes(n))[0]
+  def read_str(): return str(read_bytes(read_uint64()), "utf-8")
+  def read_arr():
+    reader, n = readers[read_int32()], read_uint64()
+    return [ reader() for _ in range(n) ]
+
+  readers: Dict[int, Callable[[], Any]] = { 8: read_str, 9: read_arr, **{ t: functools.partial(read_unpack, "<"+f, nb) for t, f, nb in [ (0,"c",1),
+    (1,"b",1), (2,"H",2), (3,"h",2), (4,"I",4), (5,"i",4), (6,"f",4), (7,"?",1), (10,"Q",8), (11,"q",8), (12,"d",8) ] } }
+  read_uint32, read_int32, read_uint64, read_int64 = readers[4], readers[5], readers[10], readers[11]
+
+  magic, version, n_tensors, n_kv = read_bytes(4), read_int32(), read_int64(), read_int64()
+  if magic != b"GGUF" or version not in [2, 3]: raise ValueError("Invalid GGUF format!")
+  for _ in range(n_kv):
+    k, typ = read_str(), read_int32()
+    kv_data[k] = readers[typ]()
+
+  t_infos = [ (read_str(), tuple(read_uint64() for _ in range(read_uint32())), read_int32(), read_uint64()) for _ in range(n_tensors) ]
+  alignment = kv_data.get("general.alignment", 32)
+  data_start = pos = pos + (alignment - pos % alignment if pos % alignment != 0 else 0)
+
+  for name, dims, typ, off in t_infos: state_dict[name] = ggml_data_to_tensor(tensor[data_start + off:], prod(dims), typ).reshape(*reversed(dims))
+
+  return kv_data, state_dict
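
Note: per the docstring above, ggml_type 2 is Q4_0: each 18-byte block stores an fp16 scale `d` followed by 32 packed 4-bit quants `q`, decoded as `(q - 8) * d`, with the low nibbles of the 16 bytes coming first. A worked sketch of that layout (assuming tinygrad 0.10.0; the block bytes here are hand-built for illustration):

```python
import struct
from tinygrad import Tensor, dtypes
from tinygrad.nn.state import ggml_data_to_tensor

d = 0.5                                  # fp16 block scale
qs = bytes([0xF1] * 16)                  # low nibble 1, high nibble 15 in every byte
block = struct.pack("<e", d) + qs        # 2 + 16 = 18 bytes = one Q4_0 block of 32 elements
out = ggml_data_to_tensor(Tensor(list(block), dtype=dtypes.uint8), 32, 2)
print(out.numpy())   # first 16 values: (1-8)*0.5 = -3.5, last 16 values: (15-8)*0.5 = 3.5
```
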