tinygrad 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. tinygrad/__init__.py +6 -6
  2. tinygrad/codegen/kernel.py +230 -190
  3. tinygrad/codegen/linearizer.py +278 -384
  4. tinygrad/codegen/uops.py +415 -0
  5. tinygrad/device.py +132 -275
  6. tinygrad/dtype.py +53 -37
  7. tinygrad/engine/__init__.py +0 -0
  8. tinygrad/engine/graph.py +100 -0
  9. tinygrad/engine/jit.py +195 -0
  10. tinygrad/engine/realize.py +191 -0
  11. tinygrad/engine/schedule.py +362 -0
  12. tinygrad/engine/search.py +196 -0
  13. tinygrad/{mlops.py → function.py} +28 -14
  14. tinygrad/helpers.py +72 -43
  15. tinygrad/lazy.py +141 -240
  16. tinygrad/multi.py +169 -0
  17. tinygrad/nn/__init__.py +179 -8
  18. tinygrad/nn/datasets.py +7 -0
  19. tinygrad/nn/optim.py +106 -28
  20. tinygrad/nn/state.py +86 -17
  21. tinygrad/ops.py +70 -44
  22. tinygrad/renderer/__init__.py +61 -0
  23. tinygrad/renderer/assembly.py +276 -0
  24. tinygrad/renderer/cstyle.py +299 -206
  25. tinygrad/renderer/llvmir.py +118 -123
  26. tinygrad/runtime/autogen/amd_gpu.py +1900 -0
  27. tinygrad/runtime/autogen/comgr.py +865 -0
  28. tinygrad/runtime/autogen/cuda.py +5923 -0
  29. tinygrad/runtime/autogen/hip.py +5909 -0
  30. tinygrad/runtime/autogen/hsa.py +5761 -0
  31. tinygrad/runtime/autogen/kfd.py +812 -0
  32. tinygrad/runtime/autogen/nv_gpu.py +33328 -0
  33. tinygrad/runtime/autogen/opencl.py +1795 -0
  34. tinygrad/runtime/driver/hip_comgr.py +47 -0
  35. tinygrad/runtime/driver/hsa.py +143 -0
  36. tinygrad/runtime/graph/clang.py +38 -0
  37. tinygrad/runtime/graph/cuda.py +59 -54
  38. tinygrad/runtime/graph/hcq.py +143 -0
  39. tinygrad/runtime/graph/hsa.py +171 -0
  40. tinygrad/runtime/graph/metal.py +37 -41
  41. tinygrad/runtime/ops_amd.py +564 -0
  42. tinygrad/runtime/ops_clang.py +16 -14
  43. tinygrad/runtime/ops_cuda.py +130 -38
  44. tinygrad/runtime/ops_disk.py +45 -42
  45. tinygrad/runtime/ops_gpu.py +52 -50
  46. tinygrad/runtime/ops_hsa.py +278 -0
  47. tinygrad/runtime/ops_llvm.py +36 -56
  48. tinygrad/runtime/ops_metal.py +42 -24
  49. tinygrad/runtime/ops_npy.py +9 -0
  50. tinygrad/runtime/ops_nv.py +630 -0
  51. tinygrad/runtime/ops_python.py +204 -0
  52. tinygrad/shape/shapetracker.py +41 -105
  53. tinygrad/shape/symbolic.py +98 -95
  54. tinygrad/shape/view.py +137 -35
  55. tinygrad/tensor.py +2367 -442
  56. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
  57. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/METADATA +19 -9
  58. tinygrad-0.9.0.dist-info/RECORD +60 -0
  59. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
  60. tinygrad/features/image.py +0 -93
  61. tinygrad/features/multi.py +0 -103
  62. tinygrad/features/search.py +0 -160
  63. tinygrad/graph.py +0 -106
  64. tinygrad/jit.py +0 -152
  65. tinygrad/realize.py +0 -50
  66. tinygrad/runtime/graph/hip.py +0 -24
  67. tinygrad/runtime/ops_cpu.py +0 -45
  68. tinygrad/runtime/ops_hip.py +0 -97
  69. tinygrad/runtime/ops_torch.py +0 -49
  70. tinygrad-0.8.0.dist-info/RECORD +0 -41
  71. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
tinygrad/nn/__init__.py CHANGED
@@ -5,6 +5,30 @@ from tinygrad.helpers import prod
  from tinygrad.nn import optim, state # noqa: F401

  class BatchNorm2d:
+ """
+ Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs with additional channel dimension).
+
+ - Described: https://paperswithcode.com/method/batch-normalization
+ - Paper: https://arxiv.org/abs/1502.03167v3
+
+ See: `Tensor.batchnorm`
+
+ ```python exec="true" session="tensor"
+ from tinygrad import Tensor, dtypes, nn
+ import numpy as np
+ np.set_printoptions(precision=4)
+ ```
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ norm = nn.BatchNorm2d(3)
+ t = Tensor.rand(2, 3, 4, 4)
+ print(t.mean().item(), t.std().item())
+ ```
+ ```python exec="true" source="above" session="tensor" result="python"
+ t = norm(t)
+ print(t.mean().item(), t.std().item())
+ ```
+ """
  def __init__(self, sz:int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
  self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum

@@ -20,7 +44,7 @@ class BatchNorm2d:
  # https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
  # There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
  batch_mean = x.mean(axis=(0,2,3))
- y = (x - batch_mean.reshape(shape=[1, -1, 1, 1]))
+ y = (x - batch_mean.detach().reshape(shape=[1, -1, 1, 1])) # d(var)/d(mean) = 0
  batch_var = (y*y).mean(axis=(0,2,3))
  batch_invstd = batch_var.add(self.eps).pow(-0.5)

@@ -38,9 +62,39 @@ class BatchNorm2d:

  # TODO: these Conv lines are terrible
  def Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
+ """
+ Applies a 1D convolution over an input signal composed of several input planes.
+
+ See: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ conv = nn.Conv1d(1, 1, 3)
+ t = Tensor.rand(1, 1, 4)
+ print(t.numpy())
+ ```
+ ```python exec="true" source="above" session="tensor" result="python"
+ t = conv(t)
+ print(t.numpy())
+ ```
+ """
  return Conv2d(in_channels, out_channels, (kernel_size,), stride, padding, dilation, groups, bias)

  class Conv2d:
+ """
+ Applies a 2D convolution over an input signal composed of several input planes.
+
+ See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ conv = nn.Conv2d(1, 1, 3)
+ t = Tensor.rand(1, 1, 4, 4)
+ print(t.numpy())
+ ```
+ ```python exec="true" source="above" session="tensor" result="python"
+ t = conv(t)
+ print(t.numpy())
+ ```
+ """
  def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
  self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
  self.stride, self.padding, self.dilation, self.groups = stride, padding, dilation, groups
@@ -55,9 +109,39 @@ class Conv2d:
  return Tensor.kaiming_uniform(out_channels, in_channels//groups, *self.kernel_size, a=math.sqrt(5))

  def ConvTranspose1d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
+ """
+ Applies a 1D transposed convolution operator over an input signal composed of several input planes.
+
+ See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ conv = nn.ConvTranspose1d(1, 1, 3)
+ t = Tensor.rand(1, 1, 4)
+ print(t.numpy())
+ ```
+ ```python exec="true" source="above" session="tensor" result="python"
+ t = conv(t)
+ print(t.numpy())
+ ```
+ """
  return ConvTranspose2d(in_channels, out_channels, (kernel_size,), stride, padding, output_padding, dilation, groups, bias)

  class ConvTranspose2d(Conv2d):
+ """
+ Applies a 2D transposed convolution operator over an input image.
+
+ See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ conv = nn.ConvTranspose2d(1, 1, 3)
+ t = Tensor.rand(1, 1, 4, 4)
+ print(t.numpy())
+ ```
+ ```python exec="true" source="above" session="tensor" result="python"
+ t = conv(t)
+ print(t.numpy())
+ ```
+ """
  def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
  super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
  self.output_padding = output_padding
@@ -70,6 +154,21 @@ class ConvTranspose2d(Conv2d):
  return Tensor.kaiming_uniform(in_channels, out_channels//groups, *self.kernel_size, a=math.sqrt(5))

  class Linear:
+ """
+ Applies a linear transformation to the incoming data.
+
+ See: https://pytorch.org/docs/stable/generated/torch.nn.Linear
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ lin = nn.Linear(3, 4)
+ t = Tensor.rand(2, 3)
+ print(t.numpy())
+ ```
+ ```python exec="true" source="above" session="tensor" result="python"
+ t = lin(t)
+ print(t.numpy())
+ ```
+ """
  def __init__(self, in_features, out_features, bias=True):
  # TODO: is this init good? torch inits to uniform(-1/sqrt(in_features), 1/sqrt(in_features))
  self.weight = Tensor.kaiming_uniform(out_features, in_features, a=math.sqrt(5))
@@ -80,6 +179,22 @@ class Linear:
  return x.linear(self.weight.transpose(), self.bias)

  class GroupNorm:
+ """
+ Applies Group Normalization over a mini-batch of inputs.
+
+ - Described: https://paperswithcode.com/method/group-normalization
+ - Paper: https://arxiv.org/abs/1803.08494v3
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ norm = nn.GroupNorm(2, 12)
+ t = Tensor.rand(2, 12, 4, 4) * 2 + 1
+ print(t.mean().item(), t.std().item())
+ ```
+ ```python exec="true" source="above" session="tensor" result="python"
+ t = norm(t)
+ print(t.mean().item(), t.std().item())
+ ```
+ """
  def __init__(self, num_groups:int, num_channels:int, eps:float=1e-5, affine:bool=True):
  self.num_groups, self.num_channels, self.eps = num_groups, num_channels, eps
  self.weight: Optional[Tensor] = Tensor.ones(num_channels) if affine else None
@@ -95,6 +210,22 @@ class GroupNorm:
  return x * self.weight.reshape(1, -1, *[1] * (len(x.shape)-2)) + self.bias.reshape(1, -1, *[1] * (len(x.shape)-2))

  class InstanceNorm:
+ """
+ Applies Instance Normalization over a mini-batch of inputs.
+
+ - Described: https://paperswithcode.com/method/instance-normalization
+ - Paper: https://arxiv.org/abs/1607.08022v3
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ norm = nn.InstanceNorm(3)
+ t = Tensor.rand(2, 3, 4, 4) * 2 + 1
+ print(t.mean().item(), t.std().item())
+ ```
+ ```python exec="true" source="above" session="tensor" result="python"
+ t = norm(t)
+ print(t.mean().item(), t.std().item())
+ ```
+ """
  def __init__(self, num_features:int, eps:float=1e-5, affine:bool=True):
  self.num_features, self.eps = num_features, eps
  self.weight: Optional[Tensor] = Tensor.ones(num_features) if affine else None
@@ -106,6 +237,22 @@ class InstanceNorm:
  return x * self.weight.reshape(1, -1, *[1] * (len(x.shape)-2)) + self.bias.reshape(1, -1, *[1] * (len(x.shape)-2))

  class LayerNorm:
+ """
+ Applies Layer Normalization over a mini-batch of inputs.
+
+ - Described: https://paperswithcode.com/method/layer-normalization
+ - Paper: https://arxiv.org/abs/1607.06450v1
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ norm = nn.LayerNorm(3)
+ t = Tensor.rand(2, 5, 3) * 2 + 1
+ print(t.mean().item(), t.std().item())
+ ```
+ ```python exec="true" source="above" session="tensor" result="python"
+ t = norm(t)
+ print(t.mean().item(), t.std().item())
+ ```
+ """
  def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps:float=1e-5, elementwise_affine:bool=True):
  self.normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
  self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(self.normalized_shape))), eps, elementwise_affine
@@ -118,16 +265,40 @@ class LayerNorm:
  return x * self.weight + self.bias

  class LayerNorm2d(LayerNorm):
+ """
+ Applies Layer Normalization over a mini-batch of 2D inputs.
+
+ See: `LayerNorm`
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ norm = nn.LayerNorm2d(3)
+ t = Tensor.rand(2, 3, 4, 4) * 2 + 1
+ print(t.mean().item(), t.std().item())
+ ```
+ ```python exec="true" source="above" session="tensor" result="python"
+ t = norm(t)
+ print(t.mean().item(), t.std().item())
+ ```
+ """
  def __call__(self, x): return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

  class Embedding:
+ """
+ A simple lookup table that stores embeddings of a fixed dictionary and size.
+
+ See: https://pytorch.org/docs/stable/generated/torch.nn.Embedding
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ emb = nn.Embedding(10, 3)
+ print(emb(Tensor([1, 2, 3, 1])).numpy())
+ ```
+ """
  def __init__(self, vocab_size:int, embed_size:int):
- self.vocab_size, self.embed_size = vocab_size, embed_size
- self.weight = Tensor.glorot_uniform(vocab_size, embed_size)
+ self.vocab_sz, self.embed_sz, self.weight = vocab_size, embed_size, Tensor.glorot_uniform(vocab_size, embed_size)

  def __call__(self, idx:Tensor) -> Tensor:
- if not hasattr(self, 'vocab_counter'):
- self.vocab_counter = Tensor.arange(self.vocab_size, requires_grad=False, device=self.weight.device).reshape(1, 1, self.vocab_size)
- batch_size, seqlen = idx.shape
- if seqlen == 0: return Tensor.empty(batch_size, 0, self.embed_size, device=self.weight.device)
- return (self.vocab_counter == idx.unsqueeze(2)).expand(*idx.shape, self.vocab_size) @ self.weight
+ if idx.numel() == 0: return Tensor.empty(idx.shape+(self.embed_sz,), device=self.weight.device)
+ arange_shp, weight_shp, big_shp = (1, 1, self.vocab_sz, 1), (1, 1, self.vocab_sz, self.embed_sz), idx.shape+(self.vocab_sz, self.embed_sz,)
+ if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).reshape(arange_shp)
+ arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1,)).expand(big_shp), self.weight.reshape(weight_shp).expand(big_shp)
+ return (arange == idx).mul(vals).sum(2)
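Note on the Embedding change above: 0.9.0 drops the one-hot matmul against `self.weight` in favor of broadcasting an arange over the indices and summing out the vocab axis. A minimal sketch of that lookup against the public 0.9.0 Tensor API (the shapes and variable names here are illustrative, not taken from the diff):

```python
from tinygrad import Tensor

# illustrative sizes; the real Embedding caches the arange on first call
vocab_sz, embed_sz = 10, 3
weight = Tensor.glorot_uniform(vocab_sz, embed_sz)
idx = Tensor([[1, 2, 3, 1]])                                   # (batch=1, seq=4)

big_shp = idx.shape + (vocab_sz, embed_sz)
arange = Tensor.arange(vocab_sz).reshape(1, 1, vocab_sz, 1).expand(big_shp)
vals = weight.reshape(1, 1, vocab_sz, embed_sz).expand(big_shp)
sel = idx.reshape(idx.shape + (1, 1)).expand(big_shp)

out = (arange == sel).mul(vals).sum(2)                         # vocab axis reduced away
print(out.shape)                                               # (1, 4, 3)
```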
tinygrad/nn/datasets.py ADDED
@@ -0,0 +1,7 @@
+ import gzip
+ from tinygrad import Tensor, fetch
+
+ def _fetch_mnist(file, offset): return Tensor(gzip.open(fetch("https://storage.googleapis.com/cvdf-datasets/mnist/"+file)).read()[offset:])
+ def mnist():
+ return _fetch_mnist("train-images-idx3-ubyte.gz", 0x10).reshape(-1, 1, 28, 28), _fetch_mnist("train-labels-idx1-ubyte.gz", 8), \
+ _fetch_mnist("t10k-images-idx3-ubyte.gz", 0x10).reshape(-1, 1, 28, 28), _fetch_mnist("t10k-labels-idx1-ubyte.gz", 8)
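The new `tinygrad.nn.datasets.mnist` helper fetches and decodes the raw IDX files. A usage sketch (assumes network access on the first call; shapes follow the standard MNIST split):

```python
from tinygrad.nn.datasets import mnist

X_train, Y_train, X_test, Y_test = mnist()
print(X_train.shape, Y_train.shape)   # expected (60000, 1, 28, 28) and (60000,)
print(X_test.shape, Y_test.shape)     # expected (10000, 1, 28, 28) and (10000,)
```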
tinygrad/nn/optim.py CHANGED
@@ -1,9 +1,13 @@
  # sorted in order of increasing complexity
  from typing import List
- from tinygrad.helpers import dedup, getenv
+ from tinygrad.helpers import dedup, flatten, getenv
  from tinygrad.tensor import Tensor
+ from tinygrad.dtype import dtypes, least_upper_dtype

  class Optimizer:
+ """
+ Base class for all optimizers.
+ """
  def __init__(self, params: List[Tensor], lr: float):
  # if it's None, but being put into an optimizer, set it to True
  for x in params:
@@ -13,54 +17,128 @@ class Optimizer:
  assert len(self.params) != 0, "optimizer must have at least one param"
  self.device = self.params[0].device
  self.buffers: List[Tensor] = dedup([x for x in params if not x.requires_grad]) # buffers are still realized
- self.lr = lr if getenv("CONST_LR") else Tensor([lr], requires_grad=False, device=self.device).contiguous()
+ # store lr in at least float32 precision
+ self.lr = Tensor(lr if getenv("CONST_LR") else [lr], requires_grad=False, device=self.device,
+ dtype=least_upper_dtype(dtypes.default_float, dtypes.float32))

  def zero_grad(self):
+ """
+ Zeroes the gradients of all the parameters.
+ """
  for param in self.params: param.grad = None

- def realize(self, extra=None):
- # NOTE: in extra is too late for most of the params due to issues with assign
- Tensor.corealize(extra + self.params + self.buffers if extra is not None else self.params + self.buffers)
+ def step(self):
+ """
+ Performs a single optimization step.
+ """
+ Tensor.realize(*self.schedule_step())
+ def schedule_step(self) -> List[Tensor]:
+ """
+ Returns the tensors that need to be realized to perform a single optimization step.
+ """
+ assert Tensor.training, (
+ f"""Tensor.training={Tensor.training}, Tensor.training must be enabled to use the optimizer.
+ - help: Consider setting Tensor.training=True before calling Optimizer.step().""")
+ return self._step()+self.params+self.buffers
+ def _step(self) -> List[Tensor]: raise NotImplementedError

- class SGD(Optimizer):
- def __init__(self, params: List[Tensor], lr=0.001, momentum=0, weight_decay=0.0, nesterov=False):
+ class OptimizerGroup(Optimizer):
+ """
+ Combines multiple optimizers into one.
+ """
+ def __init__(self, *optimizers: Optimizer): # pylint: disable=super-init-not-called
+ self.optimizers = optimizers
+ self.params, self.buffers = flatten([o.params for o in self.optimizers]), flatten([o.buffers for o in self.optimizers])
+ def __getitem__(self, i): return self.optimizers[i]
+ def zero_grad(self): [o.zero_grad() for o in self.optimizers]
+ def _step(self) -> List[Tensor]: return [x for o in self.optimizers for x in o._step()]
+
+ # LARS is essentially just trust ratio to SGD so if we just set the trust coeff 0.0 its just standard SGD.
+ def SGD(params: List[Tensor], lr=0.001, momentum=0.0, weight_decay=0.0, nesterov=False, classic=False):
+ """
+ Stochastic Gradient Descent (SGD) optimizer with optional momentum and weight decay.
+
+ `classic` is a boolean flag that determines whether to use the popular momentum update rule or the classic momentum update rule.
+
+ - Described: https://paperswithcode.com/method/sgd
+ """
+ return LARS(params, lr, momentum, weight_decay, nesterov, classic, tcoef=0.0)
+
+ class LARS(Optimizer):
+ """
+ Layer-wise Adaptive Rate Scaling (LARS) optimizer with optional momentum and weight decay.
+
+ - Described: https://paperswithcode.com/method/lars
+ - Paper: https://arxiv.org/abs/1708.03888v3
+ """
+ def __init__(self, params:List[Tensor], lr=0.001, momentum=0.9, weight_decay=1e-4, nesterov=False, classic=True, tcoef=0.001):
  super().__init__(params, lr)
- self.momentum, self.wd, self.nesterov = momentum, weight_decay, nesterov
- self.b = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params] if self.momentum else []
+ self.momentum, self.wd, self.nesterov, self.classic, self.tcoef = momentum, weight_decay, nesterov, classic, tcoef
+ self.b = [Tensor.zeros(*t.shape, dtype=t.dtype, device=t.device, requires_grad=False) for t in self.params] if self.momentum else []

- # https://pytorch.org/docs/stable/generated/torch.optim.SGD.html
- def step(self) -> None:
+ def _step(self) -> List[Tensor]:
  for i, t in enumerate(self.params):
  assert t.grad is not None
- # this is needed since the grads can form a "diamond"
+ # contiguous is needed since the grads can allegedly form a "diamond"
  # TODO: fix this in lazy.py
- t.grad.realize()
- g = t.grad + self.wd * t.detach()
+ g = t.grad.contiguous()
+ if self.tcoef != 0:
+ r1 = t.detach().square().sum().sqrt()
+ r2 = g.square().sum().sqrt()
+ r = (r1 > 0).where((r2 > 0).where(self.tcoef * r1 / (r2 + self.wd * r1), 1.0), 1.0)
+ else: r = 1.0
+ g = g + self.wd * t.detach()
+ # classic momentum does post learning rate update
+ if self.classic: g = g * r * self.lr
  if self.momentum:
  self.b[i].assign(self.momentum * self.b[i] + g) # NOTE: self.b[i] is zero on the first run, no if required
  g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
- t.assign(t.detach() - g * self.lr)
- self.realize(self.b)
+ # popular momentum does pre learning rate update
+ if not self.classic: g = g * r * self.lr
+ t.assign((t.detach() - g).cast(t.dtype))
+ return self.b

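Usage note for the reworked Optimizer above: `step()` now asserts `Tensor.training` and realizes whatever `schedule_step()` returns, so a 0.9.0 training loop looks roughly like the sketch below (the toy model, data, and hyperparameters are illustrative, not from the diff):

```python
from tinygrad import Tensor
from tinygrad.nn import Linear
from tinygrad.nn.optim import SGD
from tinygrad.nn.state import get_parameters

model = Linear(4, 2)
opt = SGD(get_parameters(model), lr=0.01, momentum=0.9)

Tensor.training = True                  # required by schedule_step()'s assert
x, y = Tensor.rand(8, 4), Tensor.rand(8, 2)
loss = (model(x) - y).square().mean()
opt.zero_grad()
loss.backward()
opt.step()                              # realizes _step() + params + buffers
```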
  # LAMB is essentially just the trust ratio part of LARS applied to Adam/W so if we just set the trust ratio to 1.0 its just Adam/W.
- def AdamW(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, wd=0.01): return LAMB(params, lr, b1, b2, eps, wd, adam=True)
- def Adam(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8): return LAMB(params, lr, b1, b2, eps, 0.0, adam=True)
+ def AdamW(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.01):
+ """
+ AdamW optimizer with optional weight decay.
+
+ - Described: https://paperswithcode.com/method/adamw
+ - Paper: https://arxiv.org/abs/1711.05101v3
+ """
+ return LAMB(params, lr, b1, b2, eps, weight_decay, adam=True)
+ def Adam(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
+ """
+ Adam optimizer.
+
+ - Described: https://paperswithcode.com/method/adam
+ - Paper: https://arxiv.org/abs/1412.6980
+ """
+ return LAMB(params, lr, b1, b2, eps, 0.0, adam=True)

  class LAMB(Optimizer):
- def __init__(self, params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, wd=0.0, adam=False):
+ """
+ LAMB optimizer with optional weight decay.
+
+ - Described: https://paperswithcode.com/method/lamb
+ - Paper: https://arxiv.org/abs/1904.00962
+ """
+ def __init__(self, params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, adam=False):
  super().__init__(params, lr)
- self.b1, self.b2, self.eps, self.wd, self.adam, self.t = b1, b2, eps, wd, adam, Tensor([0], device=self.device, requires_grad=False).realize()
- self.m = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]
- self.v = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]
+ self.b1, self.b2, self.eps, self.wd, self.adam = b1, b2, eps, weight_decay, adam
+ self.b1_t, self.b2_t = (Tensor([1], dtype=dtypes.float32, device=self.device, requires_grad=False).realize() for _ in [b1, b2])
+ self.m = [Tensor.zeros(*t.shape, dtype=dtypes.float32, device=t.device, requires_grad=False).contiguous() for t in self.params]
+ self.v = [Tensor.zeros(*t.shape, dtype=dtypes.float32, device=t.device, requires_grad=False).contiguous() for t in self.params]

- def step(self) -> None:
- self.t.assign(self.t + 1)
+ def _step(self) -> List[Tensor]:
+ self.b1_t *= self.b1
+ self.b2_t *= self.b2
  for i, t in enumerate(self.params):
  assert t.grad is not None
  self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * t.grad)
  self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad))
- m_hat = self.m[i] / (1.0 - self.b1**self.t)
- v_hat = self.v[i] / (1.0 - self.b2**self.t)
+ m_hat = self.m[i] / (1.0 - self.b1_t)
+ v_hat = self.v[i] / (1.0 - self.b2_t)
  up = (m_hat / (v_hat.sqrt() + self.eps)) + self.wd * t.detach()
  if not self.adam:
  r1 = t.detach().square().sum().sqrt()
@@ -68,5 +146,5 @@ class LAMB(Optimizer):
  r = Tensor.where(r1 > 0, Tensor.where(r2 > 0, r1 / r2, 1.0), 1.0)
  else:
  r = 1.0
- t.assign(t.detach() - self.lr * r * up)
- self.realize([self.t] + self.m + self.v)
+ t.assign((t.detach() - self.lr * r * up).cast(t.dtype))
+ return [self.b1_t, self.b2_t] + self.m + self.v
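The LAMB rewrite above also replaces the integer step counter `self.t` (and the `b1**t` / `b2**t` bias-correction terms) with running products `b1_t` and `b2_t` that are multiplied by `b1` / `b2` each step. A tiny plain-Python check of that equivalence (the values are arbitrary):

```python
b1, b2 = 0.9, 0.999
b1_t, b2_t = 1.0, 1.0
for t in range(1, 6):
  b1_t, b2_t = b1_t * b1, b2_t * b2   # what _step() now maintains
  assert abs(b1_t - b1**t) < 1e-12    # what the old code recomputed each step
  assert abs(b2_t - b2**t) < 1e-12
print("running products match b1**t and b2**t")
```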
tinygrad/nn/state.py CHANGED
@@ -2,31 +2,49 @@ import os, json, pathlib, zipfile, pickle, tarfile, struct
  from tqdm import tqdm
  from typing import Dict, Union, List, Optional, Any, Tuple
  from tinygrad.tensor import Tensor
- from tinygrad.ops import GlobalCounters
  from tinygrad.dtype import dtypes
- from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap
+ from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters
  from tinygrad.shape.view import strides_for_shape
+ from tinygrad.multi import MultiLazyBuffer

- safe_dtypes = {"F16": dtypes.float16, "F32": dtypes.float32, "U8": dtypes.uint8, "I8": dtypes.int8, "I32": dtypes.int32, "I64": dtypes.int64,
- "F64": dtypes.double, "B": dtypes.bool, "I16": dtypes.short, "U16": dtypes.ushort, "UI": dtypes.uint, "UL": dtypes.ulong}
+ safe_dtypes = {"BOOL":dtypes.bool, "I8":dtypes.int8, "U8":dtypes.uint8, "I16":dtypes.int16, "U16":dtypes.uint16, "I32":dtypes.int, "U32":dtypes.uint,
+ "I64":dtypes.int64, "U64":dtypes.uint64, "F16":dtypes.float16, "BF16":dtypes.bfloat16, "F32":dtypes.float32, "F64":dtypes.float64}
  inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}

  def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]:
+ """
+ Loads a .safetensor file from disk, returning the data, metadata length, and metadata.
+ """
  t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
- json_len = t[0:1].cast(dtypes.int64).numpy()[0]
- return (t, json_len, json.loads(t[8:8+json_len].numpy().tobytes()))
+ json_len = t[0:8].bitcast(dtypes.int64).item()
+ return t, json_len, json.loads(t[8:8+json_len].numpy().tobytes())

  def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
+ """
+ Loads a .safetensor file from disk, returning the state_dict.
+
+ ```python
+ state_dict = nn.state.safe_load("test.safetensor")
+ ```
+ """
  t, json_len, metadata = safe_load_metadata(fn)
  ret = {}
  for k,v in metadata.items():
  if k == "__metadata__": continue
  dtype = safe_dtypes[v['dtype']]
- sz = (v['data_offsets'][1]-v['data_offsets'][0])//dtype.itemsize
- ret[k] = t[8+json_len+v['data_offsets'][0]:8+json_len+v['data_offsets'][0]+sz].cast(dtype).reshape(v['shape'])
+ sz = (v['data_offsets'][1]-v['data_offsets'][0])
+ ret[k] = t[8+json_len+v['data_offsets'][0]:8+json_len+v['data_offsets'][0]+sz].bitcast(dtype).reshape(v['shape'])
  return ret

  def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any]]=None):
+ """
+ Saves a state_dict to disk in a .safetensor file with optional metadata.
+
+ ```python
+ t = nn.Tensor([1, 2, 3])
+ nn.state.safe_save({'t':t}, "test.safetensor")
+ ```
+ """
  headers, offset = {}, 0
  if metadata: headers['__metadata__'] = metadata
  for k,v in tensors.items():
@@ -36,14 +54,27 @@ def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any
  j += "\x20"*((8-len(j)%8)%8)
  pathlib.Path(fn).unlink(missing_ok=True)
  t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}")
- t[0:1].cast(dtypes.int64).assign([len(j)])
- t[8:8+len(j)].assign(Tensor(list(j.encode('utf-8')), dtype=dtypes.uint8, device="cpu"))
+ t[0:8].bitcast(dtypes.int64).assign([len(j)])
+ t[8:8+len(j)].assign(list(j.encode('utf-8')))
  for k,v in safe_load(t).items(): v.assign(tensors[k])

  # state dict

  from collections import OrderedDict
  def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> Dict[str, Tensor]:
+ """
+ Returns a state_dict of the object, with optional prefix.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ class Net:
+ def __init__(self):
+ self.l1 = nn.Linear(4, 5)
+ self.l2 = nn.Linear(5, 6)
+
+ net = Net()
+ print(nn.state.get_state_dict(net).keys())
+ ```
+ """
  if isinstance(obj, tensor_type): return {prefix.strip('.'):obj}
  if hasattr(obj, '_asdict'): return get_state_dict(obj._asdict(), prefix, tensor_type) # namedtuple
  if isinstance(obj, OrderedDict): return get_state_dict(dict(obj), prefix, tensor_type)
@@ -54,9 +85,35 @@ def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> Dict[str, Tensor]:
  elif isinstance(obj, dict):
  for k,v in obj.items(): state_dict.update(get_state_dict(v, f"{prefix}{str(k)}.", tensor_type))
  return state_dict
- def get_parameters(obj) -> List[Tensor]: return list(get_state_dict(obj).values())
+ def get_parameters(obj) -> List[Tensor]:
+ """
+ ```python exec="true" source="above" session="tensor" result="python"
+ class Net:
+ def __init__(self):
+ self.l1 = nn.Linear(4, 5)
+ self.l2 = nn.Linear(5, 6)

- def load_state_dict(model, state_dict, strict=True, verbose=True):
+ net = Net()
+ print(len(nn.state.get_parameters(net)))
+ ```
+ """
+ return list(get_state_dict(obj).values())
+
+ def load_state_dict(model, state_dict:Dict[str, Tensor], strict=True, verbose=True, consume=False) -> None:
+ """
+ Loads a state_dict into a model.
+
+ ```python
+ class Net:
+ def __init__(self):
+ self.l1 = nn.Linear(4, 5)
+ self.l2 = nn.Linear(5, 6)
+
+ net = Net()
+ state_dict = nn.state.get_state_dict(net)
+ nn.state.load_state_dict(net, state_dict)
+ ```
+ """
  start_mem_used = GlobalCounters.mem_used
  with Timing("loaded weights in ", lambda et_ns: f", {(GlobalCounters.mem_used-start_mem_used)/1e9:.2f} GB loaded at {(GlobalCounters.mem_used-start_mem_used)/et_ns:.2f} GB/s"): # noqa: E501
  model_state_dict = get_state_dict(model)
@@ -67,11 +124,22 @@ def load_state_dict(model, state_dict, strict=True, verbose=True):
  if k not in state_dict and not strict:
  if DEBUG >= 1: print(f"WARNING: not loading {k}")
  continue
- v.assign(state_dict[k].to(v.device)).realize()
+ if isinstance((mlb:=v.lazydata), MultiLazyBuffer):
+ if isinstance(state_dict[k].lazydata, MultiLazyBuffer): v.replace(state_dict[k]).realize()
+ else: v.replace(state_dict[k].shard(mlb.device, mlb.axis)).realize()
+ else: v.replace(state_dict[k].to(v.device)).realize()
+ if consume: del state_dict[k]

  # torch support!

  def torch_load(fn:str) -> Dict[str, Tensor]:
+ """
+ Loads a torch .pth file from disk.
+
+ ```python
+ state_dict = nn.state.torch_load("test.pth")
+ ```
+ """
  t = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")

  offsets: Dict[Union[str, int], int] = {}
@@ -81,7 +149,7 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
  lens[storage[2]] = storage[4] * storage[1].itemsize
  if storage[2] not in offsets: return None
  byte_offset = offsets[storage[2]]+storage_offset*storage[1].itemsize
- ret = t[byte_offset:byte_offset+prod(size)].cast(storage[1])
+ ret = t[byte_offset:byte_offset+prod(size)*storage[1].itemsize].bitcast(storage[1])

  # 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
  shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]
@@ -89,10 +157,11 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
  if tuple(permute_indexes) != tuple(range(len(permute_indexes))):
  intermediate_shape = tuple([shape_strides[x][0] for x in argsort(permute_indexes)])
  assert tuple([shape_strides[i][1] for i in argsort(permute_indexes)]) == strides_for_shape(intermediate_shape), "nonpermutable strides"
- if DEBUG >= 3: print(f"WARNING: this torch load is slow. CPU to permute {intermediate_shape} with {permute_indexes}")
- assert storage[1] != dtypes.bfloat16, "can't CPU permute BF16"
+ if DEBUG >= 3: print(f"WARNING: this torch load is slow. CLANG to permute {intermediate_shape} with {permute_indexes}")
+ assert storage[1] != dtypes.bfloat16, "can't CLANG permute BF16"
  # TODO: find a nice way to support all shapetracker on disktensors
- ret = ret.cpu().reshape(intermediate_shape).permute(permute_indexes)
+ # TODO: BUG: a ".realize()" is needed here for 'GPU=1 python3 test/models/test_efficientnet.py TestEfficientNet.test_car'
+ ret = ret.clang().reshape(intermediate_shape).permute(permute_indexes).realize()

  return ret.reshape(size)
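The state.py changes above switch the safetensors reader and writer from `cast` to `bitcast`, so slices are taken in bytes and reinterpreted in place (an 8-byte JSON header length, the JSON header, then the raw tensor data). A round-trip sketch using the public API, mirroring the new docstrings (the temporary path is illustrative):

```python
from tinygrad import Tensor
from tinygrad.nn.state import safe_save, safe_load

t = Tensor([1.0, 2.0, 3.0])
safe_save({"t": t}, "/tmp/test.safetensor")

state_dict = safe_load("/tmp/test.safetensor")   # values are disk-backed Tensors
print(state_dict["t"].shape, state_dict["t"].dtype)
```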