tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. tinygrad/__init__.py +11 -6
  2. tinygrad/codegen/kernel.py +308 -175
  3. tinygrad/codegen/linearize.py +95 -0
  4. tinygrad/codegen/lowerer.py +143 -0
  5. tinygrad/codegen/transcendental.py +257 -0
  6. tinygrad/codegen/uopgraph.py +506 -0
  7. tinygrad/device.py +72 -171
  8. tinygrad/dtype.py +122 -47
  9. tinygrad/engine/jit.py +184 -87
  10. tinygrad/{lazy.py → engine/lazy.py} +74 -66
  11. tinygrad/engine/memory.py +51 -0
  12. tinygrad/engine/realize.py +86 -61
  13. tinygrad/engine/schedule.py +366 -317
  14. tinygrad/engine/search.py +58 -47
  15. tinygrad/function.py +59 -58
  16. tinygrad/helpers.py +120 -102
  17. tinygrad/multi.py +82 -78
  18. tinygrad/nn/__init__.py +116 -67
  19. tinygrad/nn/datasets.py +12 -5
  20. tinygrad/nn/optim.py +1 -1
  21. tinygrad/nn/state.py +91 -6
  22. tinygrad/ops.py +1126 -143
  23. tinygrad/renderer/__init__.py +47 -23
  24. tinygrad/renderer/cstyle.py +338 -265
  25. tinygrad/renderer/llvmir.py +125 -143
  26. tinygrad/renderer/ptx.py +225 -0
  27. tinygrad/runtime/autogen/adreno.py +17904 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
  29. tinygrad/runtime/autogen/cuda.py +6 -162
  30. tinygrad/runtime/autogen/io_uring.py +97 -63
  31. tinygrad/runtime/autogen/kfd.py +60 -47
  32. tinygrad/runtime/autogen/kgsl.py +1386 -0
  33. tinygrad/runtime/autogen/libc.py +5462 -0
  34. tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
  35. tinygrad/runtime/autogen/nvrtc.py +579 -0
  36. tinygrad/runtime/autogen/opencl.py +11 -11
  37. tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
  38. tinygrad/runtime/graph/clang.py +3 -3
  39. tinygrad/runtime/graph/cuda.py +11 -15
  40. tinygrad/runtime/graph/hcq.py +120 -107
  41. tinygrad/runtime/graph/metal.py +71 -43
  42. tinygrad/runtime/ops_amd.py +244 -323
  43. tinygrad/runtime/ops_clang.py +12 -5
  44. tinygrad/runtime/ops_cloud.py +220 -0
  45. tinygrad/runtime/ops_cuda.py +42 -99
  46. tinygrad/runtime/ops_disk.py +25 -26
  47. tinygrad/runtime/ops_dsp.py +181 -0
  48. tinygrad/runtime/ops_gpu.py +29 -16
  49. tinygrad/runtime/ops_hip.py +68 -0
  50. tinygrad/runtime/ops_llvm.py +15 -10
  51. tinygrad/runtime/ops_metal.py +147 -64
  52. tinygrad/runtime/ops_nv.py +356 -397
  53. tinygrad/runtime/ops_python.py +78 -79
  54. tinygrad/runtime/ops_qcom.py +405 -0
  55. tinygrad/runtime/support/__init__.py +0 -0
  56. tinygrad/runtime/support/compiler_cuda.py +77 -0
  57. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
  58. tinygrad/runtime/support/elf.py +38 -0
  59. tinygrad/runtime/support/hcq.py +539 -0
  60. tinygrad/shape/shapetracker.py +40 -50
  61. tinygrad/shape/view.py +102 -63
  62. tinygrad/tensor.py +1109 -365
  63. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
  64. tinygrad-0.10.0.dist-info/RECORD +77 -0
  65. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
  66. tinygrad/codegen/linearizer.py +0 -528
  67. tinygrad/codegen/uops.py +0 -451
  68. tinygrad/engine/graph.py +0 -100
  69. tinygrad/renderer/assembly.py +0 -269
  70. tinygrad/shape/symbolic.py +0 -327
  71. tinygrad-0.9.1.dist-info/RECORD +0 -63
  72. /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
  73. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
  74. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
tinygrad/nn/__init__.py CHANGED
@@ -1,12 +1,14 @@
+from __future__ import annotations
 import math
-from typing import Optional, Union, Tuple, cast
-from tinygrad.tensor import Tensor
-from tinygrad.helpers import prod
+from typing import Optional, Union, Tuple, List
+from tinygrad.tensor import Tensor, dtypes
+from tinygrad.device import is_dtype_supported
+from tinygrad.helpers import prod, make_tuple, flatten
 from tinygrad.nn import optim, state, datasets # noqa: F401
 
-class BatchNorm2d:
+class BatchNorm:
   """
-  Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs with additional channel dimension).
+  Applies Batch Normalization over a 2D or 3D input.
 
   - Described: https://paperswithcode.com/method/batch-normalization
   - Paper: https://arxiv.org/abs/1502.03167v3
@@ -20,7 +22,7 @@ class BatchNorm2d:
   ```
 
   ```python exec="true" source="above" session="tensor" result="python"
-  norm = nn.BatchNorm2d(3)
+  norm = nn.BatchNorm(3)
   t = Tensor.rand(2, 3, 4, 4)
   print(t.mean().item(), t.std().item())
   ```
@@ -32,36 +34,34 @@ class BatchNorm2d:
   def __init__(self, sz:int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
     self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum
 
-    if affine: self.weight, self.bias = Tensor.ones(sz), Tensor.zeros(sz)
-    else: self.weight, self.bias = None, None
-
-    self.running_mean, self.running_var = Tensor.zeros(sz, requires_grad=False), Tensor.ones(sz, requires_grad=False)
-    self.num_batches_tracked = Tensor.zeros(1, requires_grad=False)
-
-  def __call__(self, x:Tensor):
-    if Tensor.training:
-      # This requires two full memory accesses to x
-      # https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
-      # There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
-      batch_mean = x.mean(axis=(0,2,3))
-      y = (x - batch_mean.detach().reshape(shape=[1, -1, 1, 1])) # d(var)/d(mean) = 0
-      batch_var = (y*y).mean(axis=(0,2,3))
-      batch_invstd = batch_var.add(self.eps).pow(-0.5)
-
-      # NOTE: wow, this is done all throughout training in most PyTorch models
-      if self.track_running_stats:
-        self.running_mean.assign((1-self.momentum) * self.running_mean + self.momentum * batch_mean.detach())
-        self.running_var.assign((1-self.momentum) * self.running_var + self.momentum * prod(y.shape)/(prod(y.shape)-y.shape[1]) * batch_var.detach())
-        self.num_batches_tracked += 1
-    else:
-      batch_mean = self.running_mean
-      # NOTE: this can be precomputed for static inference. we expand it here so it fuses
-      batch_invstd = self.running_var.reshape(1, -1, 1, 1).expand(x.shape).add(self.eps).rsqrt()
-
-    return x.batchnorm(self.weight, self.bias, batch_mean, batch_invstd)
-
-# TODO: these Conv lines are terrible
-def Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
+    self.weight: Optional[Tensor] = Tensor.ones(sz) if affine else None
+    self.bias: Optional[Tensor] = Tensor.zeros(sz) if affine else None
+
+    self.num_batches_tracked = Tensor.zeros(1, dtype='long' if is_dtype_supported(dtypes.long) else 'int', requires_grad=False)
+    if track_running_stats: self.running_mean, self.running_var = Tensor.zeros(sz, requires_grad=False), Tensor.ones(sz, requires_grad=False)
+
+  def calc_stats(self, x:Tensor) -> Tuple[Tensor, Tensor]:
+    shape_mask: List[int] = [1, -1, *([1]*(x.ndim-2))]
+    if self.track_running_stats and not Tensor.training: return self.running_mean, self.running_var.reshape(shape=shape_mask).expand(x.shape)
+    # This requires two full memory accesses to x
+    # https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
+    # There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
+    batch_mean = x.mean(axis=(reduce_axes:=tuple(x for x in range(x.ndim) if x != 1)))
+    y = (x - batch_mean.detach().reshape(shape=shape_mask)) # d(var)/d(mean) = 0
+    batch_var = (y*y).mean(axis=reduce_axes)
+    return batch_mean, batch_var
+
+  def __call__(self, x:Tensor) -> Tensor:
+    batch_mean, batch_var = self.calc_stats(x)
+    # NOTE: wow, this is done all throughout training in most PyTorch models
+    if self.track_running_stats and Tensor.training:
+      self.running_mean.assign((1-self.momentum) * self.running_mean + self.momentum * batch_mean.detach())
+      self.running_var.assign((1-self.momentum) * self.running_var + self.momentum * prod(x.shape)/(prod(x.shape)-x.shape[1]) * batch_var.detach())
+      self.num_batches_tracked += 1
+    return x.batchnorm(self.weight, self.bias, batch_mean, batch_var.add(self.eps).rsqrt())
+BatchNorm2d = BatchNorm3d = BatchNorm
+
+def Conv1d(in_channels:int, out_channels:int, kernel_size:int, stride=1, padding:Union[int, str]=0, dilation=1, groups=1, bias=True) -> Conv2d:
   """
   Applies a 1D convolution over an input signal composed of several input planes.
 
@@ -95,20 +95,24 @@ class Conv2d:
   print(t.numpy())
   ```
   """
-  def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
-    self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
-    self.stride, self.padding, self.dilation, self.groups = stride, padding, dilation, groups
-    self.weight = self.initialize_weight(out_channels, in_channels, groups)
-    bound = 1 / math.sqrt(cast(int, prod(self.weight.shape[1:]))) # weight shape is always ints but mypy cannot tell
-    self.bias = Tensor.uniform(out_channels, low=-bound, high=bound) if bias else None
-
-  def __call__(self, x:Tensor):
+  def __init__(self, in_channels:int, out_channels:int, kernel_size:Union[int, Tuple[int, ...]], stride=1, padding:Union[int, Tuple[int, ...], str]=0,
+               dilation=1, groups=1, bias=True):
+    self.kernel_size = make_tuple(kernel_size, 2)
+    if isinstance(padding, str):
+      if padding.lower() != 'same': raise ValueError(f"Invalid padding string {padding!r}, only 'same' is supported")
+      if stride != 1: raise ValueError("padding='same' is not supported for strided convolutions")
+      pad = [(d*(k-1)//2, d*(k-1) - d*(k-1)//2) for d,k in zip(make_tuple(dilation, len(self.kernel_size)), self.kernel_size[::-1])]
+      padding = tuple(flatten(pad))
+    self.stride, self.dilation, self.groups, self.padding = stride, dilation, groups, padding
+    scale = 1 / math.sqrt(in_channels * prod(self.kernel_size))
+    self.weight = Tensor.uniform(out_channels, in_channels//groups, *self.kernel_size, low=-scale, high=scale)
+    self.bias: Optional[Tensor] = Tensor.uniform(out_channels, low=-scale, high=scale) if bias else None
+
+  def __call__(self, x:Tensor) -> Tensor:
     return x.conv2d(self.weight, self.bias, padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
 
-  def initialize_weight(self, out_channels, in_channels, groups):
-    return Tensor.kaiming_uniform(out_channels, in_channels//groups, *self.kernel_size, a=math.sqrt(5))
-
-def ConvTranspose1d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
+def ConvTranspose1d(in_channels:int, out_channels:int, kernel_size:int, stride=1, padding=0, output_padding=0, dilation=1,
+                    groups=1, bias=True) -> ConvTranspose2d:
   """
   Applies a 1D transposed convolution operator over an input signal composed of several input planes.
 
@@ -142,17 +146,17 @@ class ConvTranspose2d(Conv2d):
   print(t.numpy())
   ```
   """
-  def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
+  def __init__(self, in_channels:int, out_channels:int, kernel_size:Union[int, Tuple[int, ...]], stride=1, padding=0, output_padding=0,
+               dilation=1, groups=1, bias=True):
     super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
+    scale = 1 / math.sqrt(in_channels * prod(self.kernel_size))
+    self.weight = Tensor.uniform(in_channels, out_channels//groups, *self.kernel_size, low=-scale, high=scale)
     self.output_padding = output_padding
 
-  def __call__(self, x:Tensor):
+  def __call__(self, x:Tensor) -> Tensor:
     return x.conv_transpose2d(self.weight, self.bias, padding=self.padding, output_padding=self.output_padding, stride=self.stride,
                               dilation=self.dilation, groups=self.groups)
 
-  def initialize_weight(self, out_channels, in_channels, groups):
-    return Tensor.kaiming_uniform(in_channels, out_channels//groups, *self.kernel_size, a=math.sqrt(5))
-
 class Linear:
   """
   Applies a linear transformation to the incoming data.
@@ -169,13 +173,12 @@ class Linear:
   print(t.numpy())
   ```
   """
-  def __init__(self, in_features, out_features, bias=True):
-    # TODO: is this init good? torch inits to uniform(-1/sqrt(in_features), 1/sqrt(in_features))
-    self.weight = Tensor.kaiming_uniform(out_features, in_features, a=math.sqrt(5))
+  def __init__(self, in_features:int, out_features:int, bias=True):
     bound = 1 / math.sqrt(in_features)
+    self.weight = Tensor.uniform(out_features, in_features, low=-bound, high=bound)
     self.bias = Tensor.uniform(out_features, low=-bound, high=bound) if bias else None
 
-  def __call__(self, x:Tensor):
+  def __call__(self, x:Tensor) -> Tensor:
     return x.linear(self.weight.transpose(), self.bias)
 
 class GroupNorm:
@@ -195,12 +198,12 @@ class GroupNorm:
   print(t.mean().item(), t.std().item())
   ```
   """
-  def __init__(self, num_groups:int, num_channels:int, eps:float=1e-5, affine:bool=True):
+  def __init__(self, num_groups:int, num_channels:int, eps=1e-5, affine=True):
    self.num_groups, self.num_channels, self.eps = num_groups, num_channels, eps
    self.weight: Optional[Tensor] = Tensor.ones(num_channels) if affine else None
    self.bias: Optional[Tensor] = Tensor.zeros(num_channels) if affine else None
 
-  def __call__(self, x:Tensor):
+  def __call__(self, x:Tensor) -> Tensor:
     # reshape for layernorm to work as group norm
     # subtract mean and divide stddev
     x = x.reshape(x.shape[0], self.num_groups, -1).layernorm(eps=self.eps).reshape(x.shape)
@@ -226,12 +229,12 @@ class InstanceNorm:
   print(t.mean().item(), t.std().item())
   ```
   """
-  def __init__(self, num_features:int, eps:float=1e-5, affine:bool=True):
+  def __init__(self, num_features:int, eps=1e-5, affine=True):
     self.num_features, self.eps = num_features, eps
     self.weight: Optional[Tensor] = Tensor.ones(num_features) if affine else None
     self.bias: Optional[Tensor] = Tensor.zeros(num_features) if affine else None
 
-  def __call__(self, x:Tensor):
+  def __call__(self, x:Tensor) -> Tensor:
     x = x.reshape(x.shape[0], self.num_features, -1).layernorm(eps=self.eps).reshape(x.shape)
     if self.weight is None or self.bias is None: return x
     return x * self.weight.reshape(1, -1, *[1] * (len(x.shape)-2)) + self.bias.reshape(1, -1, *[1] * (len(x.shape)-2))
@@ -253,12 +256,12 @@ class LayerNorm:
   print(t.mean().item(), t.std().item())
   ```
   """
-  def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps:float=1e-5, elementwise_affine:bool=True):
-    self.normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
+  def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps=1e-5, elementwise_affine=True):
+    self.normalized_shape: Tuple[int, ...] = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
     self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(self.normalized_shape))), eps, elementwise_affine
     self.weight, self.bias = (Tensor.ones(*self.normalized_shape), Tensor.zeros(*self.normalized_shape)) if elementwise_affine else (None, None)
 
-  def __call__(self, x:Tensor):
+  def __call__(self, x:Tensor) -> Tensor:
     assert self.normalized_shape == x.shape[-len(self.normalized_shape):], f"last dimensions of {x.shape} must match {self.normalized_shape}"
     x = x.layernorm(eps=self.eps, axis=self.axis)
     if not self.elementwise_affine: return x
@@ -280,7 +283,29 @@ class LayerNorm2d(LayerNorm):
   print(t.mean().item(), t.std().item())
   ```
   """
-  def __call__(self, x): return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+  def __call__(self, x: Tensor) -> Tensor: return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+
+class RMSNorm:
+  """
+  Applies Root Mean Square Normalization to input.
+
+  - Described: https://paperswithcode.com/method/rmsnorm
+  - Paper: https://arxiv.org/abs/1910.07467
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  norm = nn.RMSNorm(4)
+  t = Tensor.arange(12, dtype=dtypes.float).reshape(3, 4)
+  print(t.numpy())
+  ```
+  ```python exec="true" source="above" session="tensor" result="python"
+  print(norm(t).numpy())
+  ```
+  """
+  def __init__(self, dim:int, eps=1e-6): self.eps, self.weight = eps, Tensor.ones(dim)
+
+  def _norm(self, x:Tensor) -> Tensor: return x * (x.square().mean(-1, keepdim=True) + self.eps).rsqrt()
+
+  def __call__(self, x:Tensor) -> Tensor: return self._norm(x.float()).cast(x.dtype) * self.weight
 
 class Embedding:
   """
@@ -298,7 +323,31 @@ class Embedding:
 
   def __call__(self, idx:Tensor) -> Tensor:
     if idx.numel() == 0: return Tensor.empty(idx.shape+(self.embed_sz,), device=self.weight.device)
-    arange_shp, weight_shp, big_shp = (1, 1, self.vocab_sz, 1), (1, 1, self.vocab_sz, self.embed_sz), idx.shape+(self.vocab_sz, self.embed_sz,)
+    arange_shp, weight_shp, big_shp = (self.vocab_sz, 1), (self.vocab_sz, self.embed_sz), idx.shape+(self.vocab_sz, self.embed_sz,)
     if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).reshape(arange_shp)
     arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1,)).expand(big_shp), self.weight.reshape(weight_shp).expand(big_shp)
-    return (arange == idx).mul(vals).sum(2)
+    return (arange == idx).mul(vals).sum(-2, acc_dtype=vals.dtype)
+
+class LSTMCell:
+  """
+  A long short-term memory (LSTM) cell.
+
+  Args:
+    input_size: The number of expected features in the input `x`
+    hidden_size: The number of features in the hidden state `h`
+    bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`
+  """
+  def __init__(self, input_size:int, hidden_size:int, bias:bool=True):
+    stdv = 1.0 / math.sqrt(hidden_size)
+    self.weight_ih = Tensor.uniform(hidden_size*4, input_size, low=-stdv, high=stdv)
+    self.weight_hh = Tensor.uniform(hidden_size*4, hidden_size, low=-stdv, high=stdv)
+    self.bias_ih, self.bias_hh = (Tensor.zeros(hidden_size*4), Tensor.zeros(hidden_size*4)) if bias else (None, None)
+
+  def __call__(self, x:Tensor, hc:Optional[Tuple[Tensor, Tensor]]=None) -> Tuple[Tensor, Tensor]:
+    if hc is None: hc = (Tensor.zeros(x.size(0), self.weight_hh.size(1), dtype=x.dtype, device=x.device),)*2
+    gates = x.linear(self.weight_ih.T, self.bias_ih) + hc[0].linear(self.weight_hh.T, self.bias_hh)
+    i, f, g, o = gates.chunk(4, dim=1)
+    i, f, g, o = i.sigmoid(), f.sigmoid(), g.tanh(), o.sigmoid()
+    new_c = f * hc[1] + i * g
+    new_h = o * new_c.tanh()
+    return (new_h.contiguous(), new_c.contiguous())
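
The `nn` changes above fold `BatchNorm2d` into a dimension-agnostic `BatchNorm` (with `BatchNorm2d`/`BatchNorm3d` kept as aliases), add `padding='same'` to the conv layers, switch `Conv2d` and `Linear` to plain uniform init, and introduce `RMSNorm` and `LSTMCell`. A minimal sketch of the new surface, based only on the signatures shown in this diff; shapes and variable names are illustrative:

```python
from tinygrad import Tensor, nn

x = Tensor.rand(2, 3, 8, 8)                    # NCHW input
bn = nn.BatchNorm(3)                           # BatchNorm2d / BatchNorm3d are now aliases of BatchNorm
conv = nn.Conv2d(3, 16, 3, padding='same')     # 'same' keeps the 8x8 spatial size; requires stride=1
print(bn(x).shape, conv(x).shape)              # (2, 3, 8, 8) (2, 16, 8, 8)

rms = nn.RMSNorm(4)                            # RMS normalization over the last dim with a learnable weight
print(rms(Tensor.rand(2, 4)).shape)            # (2, 4)

cell = nn.LSTMCell(input_size=10, hidden_size=20)
h, c = cell(Tensor.rand(5, 10))                # hc=None -> zero-initialized hidden and cell state
print(h.shape, c.shape)                        # (5, 20) (5, 20)
```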
tinygrad/nn/datasets.py CHANGED
@@ -1,8 +1,15 @@
-import gzip
 from tinygrad.tensor import Tensor
 from tinygrad.helpers import fetch
+from tinygrad.nn.state import tar_extract
 
-def _fetch_mnist(file, offset): return Tensor(gzip.open(fetch("https://storage.googleapis.com/cvdf-datasets/mnist/"+file)).read()[offset:])
-def mnist():
-  return _fetch_mnist("train-images-idx3-ubyte.gz", 0x10).reshape(-1, 1, 28, 28), _fetch_mnist("train-labels-idx1-ubyte.gz", 8), \
-         _fetch_mnist("t10k-images-idx3-ubyte.gz", 0x10).reshape(-1, 1, 28, 28), _fetch_mnist("t10k-labels-idx1-ubyte.gz", 8)
+def mnist(device=None, fashion=False):
+  base_url = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/" if fashion else "https://storage.googleapis.com/cvdf-datasets/mnist/"
+  def _mnist(file): return Tensor.from_url(base_url+file, gunzip=True)
+  return _mnist("train-images-idx3-ubyte.gz")[0x10:].reshape(-1,1,28,28).to(device), _mnist("train-labels-idx1-ubyte.gz")[8:].to(device), \
+         _mnist("t10k-images-idx3-ubyte.gz")[0x10:].reshape(-1,1,28,28).to(device), _mnist("t10k-labels-idx1-ubyte.gz")[8:].to(device)
+
+def cifar(device=None):
+  tt = tar_extract(fetch('https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz', gunzip=True))
+  train = Tensor.cat(*[tt[f"cifar-10-batches-bin/data_batch_{i}.bin"].reshape(-1, 3073).to(device) for i in range(1,6)])
+  test = tt["cifar-10-batches-bin/test_batch.bin"].reshape(-1, 3073).to(device)
+  return train[:, 1:].reshape(-1,3,32,32), train[:, 0], test[:, 1:].reshape(-1,3,32,32), test[:, 0]
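
The dataset helpers now build on `Tensor.from_url(..., gunzip=True)` and the new `tar_extract`, return tensors moved to an optional `device`, and gain a Fashion-MNIST flag plus a CIFAR-10 loader. Roughly, assuming the downloads succeed on first use:

```python
from tinygrad.nn.datasets import mnist, cifar

X_train, Y_train, X_test, Y_test = mnist()                   # (60000, 1, 28, 28) train images and labels, then the 10k test split
Xf_train, Yf_train, Xf_test, Yf_test = mnist(fashion=True)   # same layout, served from the Fashion-MNIST mirror
Xc_train, Yc_train, Xc_test, Yc_test = cifar()               # (50000, 3, 32, 32) train and (10000, 3, 32, 32) test
print(X_train.shape, Xc_train.shape)
```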
tinygrad/nn/optim.py CHANGED
@@ -126,7 +126,7 @@ class LAMB(Optimizer):
   def __init__(self, params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, adam=False):
     super().__init__(params, lr)
     self.b1, self.b2, self.eps, self.wd, self.adam = b1, b2, eps, weight_decay, adam
-    self.b1_t, self.b2_t = (Tensor([1], dtype=dtypes.float32, device=self.device, requires_grad=False).realize() for _ in [b1, b2])
+    self.b1_t, self.b2_t = (Tensor.ones((1,), dtype=dtypes.float32, device=self.device, requires_grad=False).contiguous() for _ in [b1, b2])
     self.m = [Tensor.zeros(*t.shape, dtype=dtypes.float32, device=t.device, requires_grad=False).contiguous() for t in self.params]
     self.v = [Tensor.zeros(*t.shape, dtype=dtypes.float32, device=t.device, requires_grad=False).contiguous() for t in self.params]
 
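
The LAMB change above only swaps how the `b1_t`/`b2_t` accumulators are initialized (`Tensor.ones(...).contiguous()` instead of `Tensor([1]).realize()`); the optimizer interface is unchanged. For context, a typical training step might look like this sketch, where the model and shapes are illustrative:

```python
from tinygrad import Tensor, nn
from tinygrad.nn.optim import LAMB

model = nn.Linear(4, 2)
opt = LAMB(nn.state.get_parameters(model), lr=1e-3)

with Tensor.train():        # optimizer steps require Tensor.training to be set
  loss = model(Tensor.rand(8, 4)).sum()
  opt.zero_grad()
  loss.backward()
  opt.step()
```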
tinygrad/nn/state.py CHANGED
@@ -1,5 +1,5 @@
-import os, json, pathlib, zipfile, pickle, tarfile, struct
-from typing import Dict, Union, List, Optional, Any, Tuple
+import os, json, pathlib, zipfile, pickle, tarfile, struct, functools
+from typing import Dict, Union, List, Optional, Any, Tuple, Callable
 from tinygrad.tensor import Tensor
 from tinygrad.dtype import dtypes
 from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm
@@ -16,7 +16,7 @@ def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]:
   """
   t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
   json_len = t[0:8].bitcast(dtypes.int64).item()
-  return t, json_len, json.loads(t[8:8+json_len].numpy().tobytes())
+  return t, json_len, json.loads(t[8:8+json_len].data().tobytes())
 
 def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
   """
@@ -129,6 +129,18 @@ def load_state_dict(model, state_dict:Dict[str, Tensor], strict=True, verbose=Tr
       else: v.replace(state_dict[k].to(v.device)).realize()
       if consume: del state_dict[k]
 
+def tar_extract(fn:os.PathLike) -> Dict[str, Tensor]:
+  """
+  Extracts files from a tar archive and returns them as dictionary of names (keys) and tensors (values).
+
+  ```python
+  tensors = nn.state.tar_extract("archive.tar")
+  ```
+  """
+  t = Tensor(pathlib.Path(fn))
+  with tarfile.open(fn, "r") as tar:
+    return {member.name:t[member.offset_data:member.offset_data+member.size] for member in tar if member.type == tarfile.REGTYPE}
+
 # torch support!
 
 def torch_load(fn:str) -> Dict[str, Tensor]:
@@ -159,8 +171,7 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
       if DEBUG >= 3: print(f"WARNING: this torch load is slow. CLANG to permute {intermediate_shape} with {permute_indexes}")
       assert storage[1] != dtypes.bfloat16, "can't CLANG permute BF16"
       # TODO: find a nice way to support all shapetracker on disktensors
-      # TODO: BUG: a ".realize()" is needed here for 'GPU=1 python3 test/models/test_efficientnet.py TestEfficientNet.test_car'
-      ret = ret.clang().reshape(intermediate_shape).permute(permute_indexes).realize()
+      ret = ret.to(None).reshape(intermediate_shape).permute(permute_indexes)
 
     return ret.reshape(size)
 
@@ -168,7 +179,8 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
     def __setstate__(self, state): self.tensor = state[0]
 
   deserialized_objects: Dict[str, Any] = {}
-  intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16, "IntStorage": dtypes.int32,
+  intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16,
+               "IntStorage": dtypes.int32, "BoolStorage": dtypes.bool,
                "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2, "FloatTensor": None, "Parameter": Parameter}
   whitelist = {"torch", "collections", "numpy", "_codecs"} # NOTE: this is not for security, only speed
   class Dummy: pass
@@ -214,3 +226,76 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
         base_offset += 8 + lens[i]
       f.seek(rwd)
       return TorchPickle(f).load()
+
+def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
+  """
+  Converts ggml tensor data to a tinygrad tensor.
+
+  Supported native types: float32 (id: 0), float16 (id: 1), int8 (id: 16), int16 (id: 17), int32 (id: 18)
+  Supported quantized types: Q4_0 (id: 2), Q4_1 (id: 3), Q8_0 (id: 8), Q6_K (id: 14)
+  """
+  # https://github.com/ggerganov/ggml/blob/6dccc647264f5429df2624f36138f601e7ce23e5/include/ggml.h#L356
+
+  # native types
+  if (dtype := { 0: dtypes.float32, 1: dtypes.float16, 16: dtypes.int8, 17: dtypes.int16, 18: dtypes.int32 }.get(ggml_type)) is not None:
+    return t[:dtype.itemsize * n].bitcast(dtype)
+
+  def q_to_uint8(t: Tensor, b: int) -> Tensor:
+    # TODO: rewrite with arange?
+    shift_tensor, bitmask = Tensor.stack(*[ Tensor(2**(i*b), device=t.device, dtype=t.dtype) for i in range(8//b) ]), 0xff >> (8 - b)
+    return t.unsqueeze(-1).expand((*t.shape,8//b)).idiv(shift_tensor).bitwise_and(bitmask).transpose(-1, -2).flatten(-2)
+
+  # map to (number of elements, number of bytes)
+  if (nelements_nbytes := { 2: (32, 18), 3: (32, 20), 14: (256, 210), 8: (32, 34) }.get(ggml_type)) is not None:
+    blocks = t[:(n//nelements_nbytes[0])*nelements_nbytes[1]].reshape((-1, nelements_nbytes[1]))
+    if ggml_type == 2: return (q_to_uint8(blocks[:,2:], 4).bitcast(dtypes.int8) - 8) * blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32)
+    if ggml_type == 3:
+      d, m = (blocks[:,s:s+2].bitcast(dtypes.float16).cast(dtypes.float32) for s in [ 0, 2 ])
+      return q_to_uint8(blocks[:,4:], 4).bitcast(dtypes.int8) * d + m
+    if ggml_type == 8: return blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32) * blocks[:,2:].bitcast(dtypes.int8)
+    if ggml_type == 14:
+      xl, xh = q_to_uint8(blocks[:,:128].reshape((-1, 2, 64)), 4), q_to_uint8(blocks[:,128:192].reshape((-1, 2, 32)), 2).lshift(4)
+      scales = blocks[:,192:208].bitcast(dtypes.int8).unsqueeze(-1).expand((-1, 16, 16)).reshape((-1, 256))
+      d = blocks[:,-2:].bitcast(dtypes.float16).cast(dtypes.float32).expand((-1, 256))
+      return d * (xl.bitwise_or(xh).bitcast(dtypes.int8) - 32).flatten(-2) * scales
+  raise ValueError(f"GGML type '{ggml_type}' is not supported!")
+
+def gguf_load(tensor: Tensor) -> Tuple[Dict, Dict[str, Tensor]]:
+  """
+  Loads a gguf file from a tensor.
+
+  ```python
+  fn = "Meta-Llama-3-8B-Instruct.Q4_0.gguf"
+  gguf_tensor = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}").to(Device.DEFAULT)
+  kv_data, state_dict = gguf_load(gguf_tensor)
+  ```
+  """
+  if tensor.dtype != dtypes.uint8 or len(tensor.shape) != 1: raise ValueError("GGUF tensor must be 1d and of dtype uint8!")
+  pos, read_buffer, rb_start, kv_data, state_dict = 0, memoryview(bytes()), 0, {}, {}
+  def read_bytes(n: int):
+    nonlocal pos, read_buffer, rb_start
+    if rb_start + len(read_buffer) < pos + n: rb_start, read_buffer = pos, tensor[pos:(pos+max(n, 1000_000))].data()
+    return read_buffer[pos-rb_start:(pos:=pos+n)-rb_start]
+  def read_unpack(fmt: str, n: int): return struct.unpack(fmt, read_bytes(n))[0]
+  def read_str(): return str(read_bytes(read_uint64()), "utf-8")
+  def read_arr():
+    reader, n = readers[read_int32()], read_uint64()
+    return [ reader() for _ in range(n) ]
+
+  readers: Dict[int, Callable[[], Any]] = { 8: read_str, 9: read_arr, **{ t: functools.partial(read_unpack, "<"+f, nb) for t, f, nb in [ (0,"c",1),
+    (1,"b",1), (2,"H",2), (3,"h",2), (4,"I",4), (5,"i",4), (6,"f",4), (7,"?",1), (10,"Q",8), (11,"q",8), (12,"d",8) ] } }
+  read_uint32, read_int32, read_uint64, read_int64 = readers[4], readers[5], readers[10], readers[11]
+
+  magic, version, n_tensors, n_kv = read_bytes(4), read_int32(), read_int64(), read_int64()
+  if magic != b"GGUF" or version not in [2, 3]: raise ValueError("Invalid GGUF format!")
+  for _ in range(n_kv):
+    k, typ = read_str(), read_int32()
+    kv_data[k] = readers[typ]()
+
+  t_infos = [ (read_str(), tuple(read_uint64() for _ in range(read_uint32())), read_int32(), read_uint64()) for _ in range(n_tensors) ]
+  alignment = kv_data.get("general.alignment", 32)
+  data_start = pos = pos + (alignment - pos % alignment if pos % alignment != 0 else 0)
+
+  for name, dims, typ, off in t_infos: state_dict[name] = ggml_data_to_tensor(tensor[data_start + off:], prod(dims), typ).reshape(*reversed(dims))
+
+  return kv_data, state_dict
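
Together, `tar_extract`, `ggml_data_to_tensor`, and `gguf_load` let a checkpoint be read straight from a disk-backed tensor, with the quantized GGUF block formats (Q4_0, Q4_1, Q8_0, Q6_K) dequantized through ordinary tensor ops. A hedged usage sketch, mirroring the docstring example above; the file names are placeholders:

```python
import os
from tinygrad import Tensor, Device, dtypes
from tinygrad.nn.state import tar_extract, gguf_load

# tar_extract: each regular file in the archive becomes a uint8 Tensor slice of the archive
members = tar_extract("weights.tar")                     # hypothetical archive path
print(list(members.keys())[:3])

# gguf_load: copy the file onto the default device, then parse metadata and tensors
fn = "model.Q4_0.gguf"                                   # hypothetical GGUF checkpoint
gguf_tensor = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}").to(Device.DEFAULT)
kv_data, state_dict = gguf_load(gguf_tensor)
print(len(kv_data), len(state_dict))
```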