tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +1 -1
- tinygrad/apps/llm.py +206 -0
- tinygrad/codegen/__init__.py +116 -0
- tinygrad/codegen/devectorizer.py +315 -172
- tinygrad/codegen/expander.py +8 -16
- tinygrad/codegen/gpudims.py +89 -0
- tinygrad/codegen/linearize.py +205 -203
- tinygrad/codegen/lowerer.py +92 -139
- tinygrad/codegen/opt/__init__.py +38 -0
- tinygrad/codegen/opt/heuristic.py +125 -0
- tinygrad/codegen/opt/kernel.py +510 -0
- tinygrad/{engine → codegen/opt}/search.py +51 -35
- tinygrad/codegen/opt/swizzler.py +134 -0
- tinygrad/codegen/opt/tc.py +127 -0
- tinygrad/codegen/quantize.py +67 -0
- tinygrad/device.py +122 -132
- tinygrad/dtype.py +152 -35
- tinygrad/engine/jit.py +81 -54
- tinygrad/engine/memory.py +46 -27
- tinygrad/engine/realize.py +82 -41
- tinygrad/engine/schedule.py +70 -445
- tinygrad/frontend/__init__.py +0 -0
- tinygrad/frontend/onnx.py +1253 -0
- tinygrad/frontend/torch.py +5 -0
- tinygrad/gradient.py +19 -27
- tinygrad/helpers.py +95 -47
- tinygrad/nn/__init__.py +7 -8
- tinygrad/nn/optim.py +72 -41
- tinygrad/nn/state.py +37 -23
- tinygrad/renderer/__init__.py +40 -60
- tinygrad/renderer/cstyle.py +143 -128
- tinygrad/renderer/llvmir.py +113 -62
- tinygrad/renderer/ptx.py +50 -32
- tinygrad/renderer/wgsl.py +27 -23
- tinygrad/runtime/autogen/am/am.py +5861 -0
- tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
- tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
- tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
- tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
- tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
- tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
- tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
- tinygrad/runtime/autogen/comgr.py +35 -9
- tinygrad/runtime/autogen/comgr_3.py +906 -0
- tinygrad/runtime/autogen/cuda.py +2419 -494
- tinygrad/runtime/autogen/hsa.py +57 -16
- tinygrad/runtime/autogen/ib.py +7171 -0
- tinygrad/runtime/autogen/io_uring.py +917 -118
- tinygrad/runtime/autogen/kfd.py +748 -26
- tinygrad/runtime/autogen/libc.py +613 -218
- tinygrad/runtime/autogen/libusb.py +1643 -0
- tinygrad/runtime/autogen/nv/nv.py +8602 -0
- tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
- tinygrad/runtime/autogen/opencl.py +2 -4
- tinygrad/runtime/autogen/sqtt.py +1789 -0
- tinygrad/runtime/autogen/vfio.py +3 -3
- tinygrad/runtime/autogen/webgpu.py +273 -264
- tinygrad/runtime/graph/cuda.py +3 -3
- tinygrad/runtime/graph/hcq.py +68 -29
- tinygrad/runtime/graph/metal.py +29 -13
- tinygrad/runtime/graph/remote.py +114 -0
- tinygrad/runtime/ops_amd.py +537 -320
- tinygrad/runtime/ops_cpu.py +108 -7
- tinygrad/runtime/ops_cuda.py +12 -14
- tinygrad/runtime/ops_disk.py +13 -10
- tinygrad/runtime/ops_dsp.py +47 -40
- tinygrad/runtime/ops_gpu.py +13 -11
- tinygrad/runtime/ops_hip.py +6 -9
- tinygrad/runtime/ops_llvm.py +35 -15
- tinygrad/runtime/ops_metal.py +29 -19
- tinygrad/runtime/ops_npy.py +5 -3
- tinygrad/runtime/ops_null.py +28 -0
- tinygrad/runtime/ops_nv.py +306 -234
- tinygrad/runtime/ops_python.py +62 -52
- tinygrad/runtime/ops_qcom.py +28 -39
- tinygrad/runtime/ops_remote.py +482 -0
- tinygrad/runtime/ops_webgpu.py +28 -28
- tinygrad/runtime/support/am/amdev.py +114 -249
- tinygrad/runtime/support/am/ip.py +211 -172
- tinygrad/runtime/support/amd.py +138 -0
- tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
- tinygrad/runtime/support/compiler_cuda.py +8 -11
- tinygrad/runtime/support/elf.py +2 -1
- tinygrad/runtime/support/hcq.py +184 -97
- tinygrad/runtime/support/ib.py +172 -0
- tinygrad/runtime/support/llvm.py +3 -4
- tinygrad/runtime/support/memory.py +251 -0
- tinygrad/runtime/support/nv/__init__.py +0 -0
- tinygrad/runtime/support/nv/ip.py +581 -0
- tinygrad/runtime/support/nv/nvdev.py +183 -0
- tinygrad/runtime/support/system.py +170 -0
- tinygrad/runtime/support/usb.py +268 -0
- tinygrad/runtime/support/webgpu.py +18 -0
- tinygrad/schedule/__init__.py +0 -0
- tinygrad/schedule/grouper.py +119 -0
- tinygrad/schedule/kernelize.py +368 -0
- tinygrad/schedule/multi.py +231 -0
- tinygrad/shape/shapetracker.py +40 -46
- tinygrad/shape/view.py +88 -52
- tinygrad/tensor.py +968 -542
- tinygrad/uop/__init__.py +117 -0
- tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
- tinygrad/uop/mathtraits.py +169 -0
- tinygrad/uop/ops.py +1021 -0
- tinygrad/uop/spec.py +228 -0
- tinygrad/{codegen → uop}/symbolic.py +239 -216
- tinygrad/uop/upat.py +163 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
- tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
- tinygrad/viz/index.html +203 -403
- tinygrad/viz/js/index.js +718 -0
- tinygrad/viz/js/worker.js +29 -0
- tinygrad/viz/serve.py +224 -102
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
- tinygrad-0.11.0.dist-info/RECORD +141 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/kernel.py +0 -693
- tinygrad/engine/multi.py +0 -161
- tinygrad/ops.py +0 -1003
- tinygrad/runtime/ops_cloud.py +0 -220
- tinygrad/runtime/support/allocator.py +0 -94
- tinygrad/spec.py +0 -155
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
- tinygrad/viz/perfetto.html +0 -178
- tinygrad-0.10.2.dist-info/RECORD +0 -99
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
tinygrad/nn/optim.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
# sorted in order of increasing complexity
|
2
|
-
|
2
|
+
import itertools
|
3
|
+
from tinygrad.helpers import dedup, flatten, getenv, unwrap, FUSE_OPTIM
|
3
4
|
from tinygrad.tensor import Tensor
|
4
5
|
from tinygrad.dtype import dtypes, least_upper_dtype
|
5
6
|
|
@@ -7,7 +8,7 @@ class Optimizer:
|
|
7
8
|
"""
|
8
9
|
Base class for all optimizers.
|
9
10
|
"""
|
10
|
-
def __init__(self, params: list[Tensor], lr: float):
|
11
|
+
def __init__(self, params: list[Tensor], lr: float, fused=FUSE_OPTIM):
|
11
12
|
# if it's None, but being put into an optimizer, set it to True
|
12
13
|
for x in params:
|
13
14
|
if x.requires_grad is None: x.requires_grad = True
|
@@ -16,9 +17,16 @@ class Optimizer:
|
|
16
17
|
assert len(self.params) != 0, "optimizer must have at least one param"
|
17
18
|
self.device = self.params[0].device
|
18
19
|
self.buffers: list[Tensor] = dedup([x for x in params if not x.requires_grad]) # buffers are still realized
|
20
|
+
self.fused = fused
|
19
21
|
# store lr in at least float32 precision
|
20
22
|
self.lr = Tensor(lr if getenv("CONST_LR") else [lr], requires_grad=False, device=self.device,
|
21
23
|
dtype=least_upper_dtype(dtypes.default_float, dtypes.float32))
|
24
|
+
if self.fused: self.pos_params = list(itertools.accumulate(self.params, lambda x,y: x+y.numel(), initial=0))
|
25
|
+
|
26
|
+
def _new_optim_param(self) -> list[Tensor]:
|
27
|
+
param_dtype = getenv("OPTIM_DTYPE", "float32")
|
28
|
+
if self.fused: return [Tensor.zeros(self.pos_params[-1], dtype=param_dtype, device=self.device, requires_grad=False).contiguous()]
|
29
|
+
return [Tensor.zeros(*t.shape, dtype=param_dtype, device=t.device, requires_grad=False).contiguous() for t in self.params]
|
22
30
|
|
23
31
|
def zero_grad(self):
|
24
32
|
"""
|
@@ -39,9 +47,17 @@ class Optimizer:
|
|
39
47
|
if not Tensor.training: raise RuntimeError(
|
40
48
|
f"""Tensor.training={Tensor.training}, Tensor.training must be enabled to use the optimizer.
|
41
49
|
- help: Consider setting Tensor.training=True before calling Optimizer.step().""")
|
42
|
-
|
43
|
-
|
44
|
-
|
50
|
+
if self.fused:
|
51
|
+
# optimizer fusion just concatenates all the buffers, runs the _step, then splits them back up
|
52
|
+
out, extra = self._step([Tensor.cat(*[t.flatten() for t in self.params], dim=0)],
|
53
|
+
[Tensor.cat(*[unwrap(t.grad).flatten() for t in self.params], dim=0)])
|
54
|
+
updated_params = [out[0][self.pos_params[i]:self.pos_params[i+1]].reshape(tt.shape) for i, tt in enumerate(self.params)]
|
55
|
+
else:
|
56
|
+
updated_params, extra = self._step(self.params, [unwrap(t.grad) for t in self.params])
|
57
|
+
for i, tt in enumerate(self.params): tt.assign(updated_params[i])
|
58
|
+
return extra+self.params+self.buffers
|
59
|
+
|
60
|
+
def _step(self, params:list[Tensor], grads:list[Tensor]) -> tuple[list[Tensor], list[Tensor]]: raise NotImplementedError
|
45
61
|
|
46
62
|
class OptimizerGroup(Optimizer):
|
47
63
|
"""
|
@@ -54,93 +70,108 @@ class OptimizerGroup(Optimizer):
|
|
54
70
|
def zero_grad(self): [o.zero_grad() for o in self.optimizers]
|
55
71
|
def schedule_step(self) -> list[Tensor]: return [x for o in self.optimizers for x in o.schedule_step()]
|
56
72
|
|
57
|
-
# LARS is essentially just trust ratio to SGD so if we just set the trust coeff 0.0
|
58
|
-
def SGD(params: list[Tensor], lr=0.001, momentum=0.0, weight_decay=0.0, nesterov=False, classic=False):
|
73
|
+
# LARS is essentially just trust ratio to SGD so if we just set the trust coeff 0.0 it's just standard SGD.
|
74
|
+
def SGD(params: list[Tensor], lr=0.001, momentum=0.0, weight_decay=0.0, nesterov=False, classic=False, fused=FUSE_OPTIM):
|
59
75
|
"""
|
60
76
|
Stochastic Gradient Descent (SGD) optimizer with optional momentum and weight decay.
|
61
77
|
|
62
78
|
`classic` is a boolean flag that determines whether to use the popular momentum update rule or the classic momentum update rule.
|
79
|
+
"""
|
80
|
+
return LARS(params, lr, momentum, weight_decay, 0, None, nesterov, classic=classic, pre_wd=True, tcoef=0.0, fused=fused)
|
63
81
|
|
64
|
-
|
82
|
+
# Muon applies the newton schulz algorithm on gradient. also can include momentum, nesterov, and weight decay
|
83
|
+
def Muon(params: list[Tensor], lr=0.02, momentum=0.95, weight_decay=0.0, ns_steps=5, ns_params=(3.4445, -4.775, 2.0315),
|
84
|
+
nesterov=True, fused=FUSE_OPTIM):
|
65
85
|
"""
|
66
|
-
|
86
|
+
SGD with newton-schulz iteration and post momentum weight decay.
|
87
|
+
|
88
|
+
- Described: https://kellerjordan.github.io/posts/muon/
|
89
|
+
- Paper: https://arxiv.org/pdf/2502.16982
|
90
|
+
"""
|
91
|
+
assert not fused, "FUSE_OPTIM not allowed for Muon optimizer"
|
92
|
+
return LARS(params, lr, momentum, weight_decay, ns_steps, ns_params, nesterov, classic=False, pre_wd=False, tcoef=0.0, fused=fused)
|
67
93
|
|
68
94
|
class LARS(Optimizer):
|
69
95
|
"""
|
70
96
|
Layer-wise Adaptive Rate Scaling (LARS) optimizer with optional momentum and weight decay.
|
71
97
|
|
72
|
-
- Described: https://paperswithcode.com/method/lars
|
73
98
|
- Paper: https://arxiv.org/abs/1708.03888v3
|
74
99
|
"""
|
75
|
-
def __init__(self, params:list[Tensor], lr=0.001, momentum=0.9, weight_decay=1e-4,
|
76
|
-
|
77
|
-
|
78
|
-
self.
|
79
|
-
|
80
|
-
|
81
|
-
|
100
|
+
def __init__(self, params:list[Tensor], lr=0.001, momentum=0.9, weight_decay=1e-4, ns_steps=0, ns_params=None,
|
101
|
+
nesterov=False, classic=True, pre_wd=True, tcoef=0.001, fused=FUSE_OPTIM):
|
102
|
+
super().__init__(params, lr, fused)
|
103
|
+
self.momentum, self.wd, self.ns_steps, self.ns_params = momentum, weight_decay, ns_steps, ns_params
|
104
|
+
self.nesterov, self.classic, self.pre_wd, self.tcoef = nesterov, classic, pre_wd, tcoef
|
105
|
+
self.b = self._new_optim_param() if self.momentum else []
|
106
|
+
|
107
|
+
def _step(self, params:list[Tensor], grads:list[Tensor]) -> tuple[list[Tensor], list[Tensor]]:
|
108
|
+
ret = []
|
109
|
+
for i, (t, g) in enumerate(zip(params, grads)):
|
82
110
|
if self.tcoef != 0:
|
83
111
|
r1 = t.detach().square().sum().sqrt()
|
84
112
|
r2 = g.square().sum().sqrt()
|
85
|
-
r = (r1 > 0).where((r2 > 0).where(self.tcoef * r1 / (r2 + self.wd * r1), 1.0), 1.0)
|
113
|
+
r:Tensor|float = (r1 > 0).where((r2 > 0).where(self.tcoef * r1 / (r2 + self.wd * r1), 1.0), 1.0)
|
86
114
|
else: r = 1.0
|
87
|
-
g = g + self.wd * t.detach()
|
115
|
+
if self.pre_wd and self.wd > 0: g = g + self.wd * t.detach()
|
88
116
|
# classic momentum does post learning rate update
|
89
117
|
if self.classic: g = g * r * self.lr
|
90
118
|
if self.momentum:
|
91
|
-
|
119
|
+
# TODO: this contiguous is required for correctness because self.b[i] becomes a non contiguous view
|
120
|
+
# the scheduler should detect this and just insert contiguous
|
121
|
+
self.b[i].assign(self.momentum * self.b[i].contiguous() + g) # NOTE: self.b[i] is zero on the first run, no if required
|
92
122
|
g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
|
123
|
+
if self.ns_params: g = g.reshape(g.shape[0], -1).newton_schulz(self.ns_steps, self.ns_params).reshape(g.shape)
|
124
|
+
# muon does post momentum weight decay
|
125
|
+
if not self.pre_wd and self.wd > 0: t = t.detach() * (1.0 - self.wd * self.lr)
|
93
126
|
# popular momentum does pre learning rate update
|
94
127
|
if not self.classic: g = g * r * self.lr
|
95
|
-
|
96
|
-
return self.b
|
128
|
+
ret.append((t.detach() - g).cast(t.dtype))
|
129
|
+
return ret, self.b
|
97
130
|
|
98
|
-
# LAMB is essentially just the trust ratio part of LARS applied to Adam/W so if we just set the trust ratio to 1.0
|
99
|
-
def AdamW(params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.01):
|
131
|
+
# LAMB is essentially just the trust ratio part of LARS applied to Adam/W so if we just set the trust ratio to 1.0 it's just Adam/W.
|
132
|
+
def AdamW(params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.01, fused=FUSE_OPTIM):
|
100
133
|
"""
|
101
134
|
AdamW optimizer with optional weight decay.
|
102
135
|
|
103
|
-
- Described: https://paperswithcode.com/method/adamw
|
104
136
|
- Paper: https://arxiv.org/abs/1711.05101v3
|
105
137
|
"""
|
106
|
-
return LAMB(params, lr, b1, b2, eps, weight_decay, adam=True)
|
107
|
-
def Adam(params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
|
138
|
+
return LAMB(params, lr, b1, b2, eps, weight_decay, adam=True, fused=fused)
|
139
|
+
def Adam(params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, fused=FUSE_OPTIM):
|
108
140
|
"""
|
109
141
|
Adam optimizer.
|
110
142
|
|
111
|
-
- Described: https://paperswithcode.com/method/adam
|
112
143
|
- Paper: https://arxiv.org/abs/1412.6980
|
113
144
|
"""
|
114
|
-
return LAMB(params, lr, b1, b2, eps, 0.0, adam=True)
|
145
|
+
return LAMB(params, lr, b1, b2, eps, 0.0, adam=True, fused=fused)
|
115
146
|
|
116
147
|
class LAMB(Optimizer):
|
117
148
|
"""
|
118
149
|
LAMB optimizer with optional weight decay.
|
119
150
|
|
120
|
-
- Described: https://paperswithcode.com/method/lamb
|
121
151
|
- Paper: https://arxiv.org/abs/1904.00962
|
122
152
|
"""
|
123
|
-
def __init__(self, params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, adam=False):
|
124
|
-
super().__init__(params, lr)
|
153
|
+
def __init__(self, params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, adam=False, fused=FUSE_OPTIM):
|
154
|
+
super().__init__(params, lr, fused)
|
125
155
|
self.b1, self.b2, self.eps, self.wd, self.adam = b1, b2, eps, weight_decay, adam
|
126
156
|
self.b1_t, self.b2_t = (Tensor.ones((1,), dtype=dtypes.float32, device=self.device, requires_grad=False).contiguous() for _ in [b1, b2])
|
127
|
-
self.m =
|
128
|
-
self.v =
|
157
|
+
self.m = self._new_optim_param()
|
158
|
+
self.v = self._new_optim_param()
|
129
159
|
|
130
|
-
def
|
160
|
+
def _step(self, params:list[Tensor], grads:list[Tensor]) -> tuple[list[Tensor], list[Tensor]]:
|
161
|
+
ret = []
|
131
162
|
self.b1_t *= self.b1
|
132
163
|
self.b2_t *= self.b2
|
133
|
-
for i, (t, g) in enumerate(zip(
|
134
|
-
self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * g)
|
135
|
-
self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (g * g))
|
164
|
+
for i, (t, g) in enumerate(zip(params, grads)):
|
165
|
+
self.m[i].assign((self.b1 * self.m[i] + (1.0 - self.b1) * g).cast(self.m[i].dtype))
|
166
|
+
self.v[i].assign((self.b2 * self.v[i] + (1.0 - self.b2) * (g * g)).cast(self.v[i].dtype))
|
136
167
|
m_hat = self.m[i] / (1.0 - self.b1_t)
|
137
168
|
v_hat = self.v[i] / (1.0 - self.b2_t)
|
138
169
|
up = (m_hat / (v_hat.sqrt() + self.eps)) + self.wd * t.detach()
|
139
170
|
if not self.adam:
|
140
171
|
r1 = t.detach().square().sum().sqrt()
|
141
172
|
r2 = up.square().sum().sqrt()
|
142
|
-
r = Tensor.where(r1 > 0, Tensor.where(r2 > 0, r1 / r2, 1.0), 1.0)
|
173
|
+
r: Tensor|float = Tensor.where(r1 > 0, Tensor.where(r2 > 0, r1 / r2, 1.0), 1.0)
|
143
174
|
else:
|
144
175
|
r = 1.0
|
145
|
-
|
146
|
-
return [self.b1_t, self.b2_t] + self.m + self.v
|
176
|
+
ret.append((t.detach() - self.lr * r * up).cast(t.dtype))
|
177
|
+
return ret, [self.b1_t, self.b2_t] + self.m + self.v
|
tinygrad/nn/state.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
import json, pathlib, zipfile, pickle, tarfile, struct, functools, io
|
2
2
|
from collections import OrderedDict
|
3
|
-
from typing import
|
3
|
+
from typing import Any, Callable, BinaryIO, Iterable
|
4
4
|
from tinygrad.tensor import Tensor
|
5
5
|
from tinygrad.dtype import dtypes
|
6
6
|
from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm, round_up, T
|
@@ -35,22 +35,22 @@ safe_dtypes = {"BOOL":dtypes.bool, "I8":dtypes.int8, "U8":dtypes.uint8, "I16":dt
|
|
35
35
|
"I64":dtypes.int64, "U64":dtypes.uint64, "F16":dtypes.float16, "BF16":dtypes.bfloat16, "F32":dtypes.float32, "F64":dtypes.float64}
|
36
36
|
inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}
|
37
37
|
|
38
|
-
def accept_filename(func: Callable[[Tensor], T]) -> Callable[[
|
38
|
+
def accept_filename(func: Callable[[Tensor], T]) -> Callable[[Tensor|str|pathlib.Path], T]:
|
39
39
|
@functools.wraps(func)
|
40
|
-
def wrapper(fn:
|
40
|
+
def wrapper(fn: Tensor|str|pathlib.Path) -> T: return func(Tensor(pathlib.Path(fn)) if not isinstance(fn, Tensor) else fn)
|
41
41
|
return wrapper
|
42
42
|
|
43
43
|
@accept_filename
|
44
44
|
def safe_load_metadata(t:Tensor) -> tuple[Tensor, int, dict[str, Any]]:
|
45
45
|
"""
|
46
|
-
Loads a .safetensor file
|
46
|
+
Loads a .safetensor file, returning the source tensor, data start position, and metadata.
|
47
47
|
"""
|
48
48
|
data_start = int.from_bytes(t[0:8].data(), "little") + 8
|
49
49
|
return t, data_start, json.loads(t[8:data_start].data().tobytes())
|
50
50
|
|
51
|
-
def safe_load(fn:
|
51
|
+
def safe_load(fn:Tensor|str|pathlib.Path) -> dict[str, Tensor]:
|
52
52
|
"""
|
53
|
-
Loads a .safetensor file
|
53
|
+
Loads a .safetensor file, returning the `state_dict`.
|
54
54
|
|
55
55
|
```python
|
56
56
|
state_dict = nn.state.safe_load("test.safetensor")
|
@@ -61,9 +61,9 @@ def safe_load(fn:Union[Tensor, str, pathlib.Path]) -> dict[str, Tensor]:
|
|
61
61
|
return { k: data[v['data_offsets'][0]:v['data_offsets'][1]].bitcast(safe_dtypes[v['dtype']]).reshape(v['shape'])
|
62
62
|
for k, v in metadata.items() if k != "__metadata__" }
|
63
63
|
|
64
|
-
def safe_save(tensors:dict[str, Tensor], fn:str, metadata:
|
64
|
+
def safe_save(tensors:dict[str, Tensor], fn:str, metadata:dict[str, Any]|None=None):
|
65
65
|
"""
|
66
|
-
Saves a state_dict to disk in a .safetensor file with optional metadata.
|
66
|
+
Saves a `state_dict` to disk in a .safetensor file with optional metadata.
|
67
67
|
|
68
68
|
```python
|
69
69
|
t = Tensor([1, 2, 3])
|
@@ -87,7 +87,7 @@ def safe_save(tensors:dict[str, Tensor], fn:str, metadata:Optional[dict[str, Any
|
|
87
87
|
|
88
88
|
def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> dict[str, Tensor]:
|
89
89
|
"""
|
90
|
-
Returns a state_dict of the object, with optional prefix.
|
90
|
+
Returns a `state_dict` of the object, with optional prefix.
|
91
91
|
|
92
92
|
```python exec="true" source="above" session="tensor" result="python"
|
93
93
|
class Net:
|
@@ -124,9 +124,9 @@ def get_parameters(obj) -> list[Tensor]:
|
|
124
124
|
"""
|
125
125
|
return list(get_state_dict(obj).values())
|
126
126
|
|
127
|
-
def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=True, consume=False) ->
|
127
|
+
def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=True, consume=False, realize=True) -> list[Tensor]:
|
128
128
|
"""
|
129
|
-
Loads a state_dict into a model.
|
129
|
+
Loads a `state_dict` into a model. Return the loaded Tensors.
|
130
130
|
|
131
131
|
```python
|
132
132
|
class Net:
|
@@ -140,7 +140,9 @@ def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=Tr
|
|
140
140
|
```
|
141
141
|
"""
|
142
142
|
start_mem_used = GlobalCounters.mem_used
|
143
|
-
|
143
|
+
ret = []
|
144
|
+
with Timing("loaded weights in ",
|
145
|
+
lambda et_ns: f", {(B:=(GlobalCounters.mem_used-start_mem_used))/1e9:.2f} GB loaded at {B/et_ns:.2f} GB/s", enabled=verbose):
|
144
146
|
model_state_dict = get_state_dict(model)
|
145
147
|
if DEBUG >= 1 and len(state_dict) > len(model_state_dict):
|
146
148
|
print("WARNING: unused weights in state_dict", sorted(list(state_dict.keys() - model_state_dict.keys())))
|
@@ -152,15 +154,22 @@ def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=Tr
|
|
152
154
|
if v.shape != state_dict[k].shape:
|
153
155
|
raise ValueError(f'Shape mismatch in layer `{k}`: Expected shape {v.shape}, but found {state_dict[k].shape} in state dict.')
|
154
156
|
if isinstance(v.device, tuple):
|
155
|
-
if isinstance(state_dict[k].device, tuple): v.replace(state_dict[k])
|
156
|
-
else: v.replace(state_dict[k].shard(v.device, v.
|
157
|
-
else: v.replace(state_dict[k].to(v.device))
|
157
|
+
if isinstance(state_dict[k].device, tuple): v.replace(state_dict[k])
|
158
|
+
else: v.replace(state_dict[k].shard(v.device, v.uop.axis))
|
159
|
+
else: v.replace(state_dict[k].to(v.device))
|
160
|
+
if realize: v.realize()
|
158
161
|
if consume: del state_dict[k]
|
162
|
+
ret.append(v)
|
163
|
+
return ret
|
159
164
|
|
160
165
|
@accept_filename
|
161
166
|
def tar_extract(t: Tensor) -> dict[str, Tensor]:
|
162
167
|
"""
|
163
|
-
|
168
|
+
```python
|
169
|
+
tar_extract(fn: Tensor | str | Path) -> dict[str, Tensor]
|
170
|
+
```
|
171
|
+
|
172
|
+
Extracts files from a tar archive and returns them as a dictionary of names (keys) and tensors (values).
|
164
173
|
|
165
174
|
```python
|
166
175
|
tensors = nn.state.tar_extract(Tensor(pathlib.Path("archive.tar")))
|
@@ -174,14 +183,18 @@ def tar_extract(t: Tensor) -> dict[str, Tensor]:
|
|
174
183
|
@accept_filename
|
175
184
|
def torch_load(t:Tensor) -> dict[str, Tensor]:
|
176
185
|
"""
|
177
|
-
|
186
|
+
```python
|
187
|
+
torch_load(fn: Tensor | str | Path) -> dict[str, Tensor]
|
188
|
+
```
|
189
|
+
|
190
|
+
Loads a torch .pth file, returning the `state_dict`.
|
178
191
|
|
179
192
|
```python
|
180
193
|
state_dict = nn.state.torch_load("test.pth")
|
181
194
|
```
|
182
195
|
"""
|
183
|
-
offsets: dict[
|
184
|
-
lens: dict[
|
196
|
+
offsets: dict[str|int, int] = {}
|
197
|
+
lens: dict[str|int, int] = {}
|
185
198
|
def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad=None, backward_hooks=None, metadata=None):
|
186
199
|
#print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
|
187
200
|
lens[storage[2]] = storage[4] * storage[1].itemsize
|
@@ -292,13 +305,14 @@ def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
|
|
292
305
|
@accept_filename
|
293
306
|
def gguf_load(tensor: Tensor) -> tuple[dict, dict[str, Tensor]]:
|
294
307
|
"""
|
295
|
-
Loads a gguf file
|
308
|
+
Loads a .gguf file, returning the `kv_data` and `state_dict`.
|
296
309
|
|
297
310
|
```python
|
298
|
-
|
299
|
-
|
300
|
-
kv_data, state_dict = gguf_load(gguf_tensor)
|
311
|
+
gguf_tensor = Tensor(pathlib.Path("Meta-Llama-3-8B-Instruct.Q4_0.gguf")).to(Device.DEFAULT)
|
312
|
+
kv_data, state_dict = nn.state.gguf_load(gguf_tensor)
|
301
313
|
```
|
314
|
+
|
315
|
+
NOTE: The provided tensor must be on a device that supports execution.
|
302
316
|
"""
|
303
317
|
reader, kv_data, state_dict = io.BufferedReader(TensorIO(tensor), 1_000_000), {}, {}
|
304
318
|
def read_unpack(fmt: str, n: int): return struct.unpack(fmt, reader.read(n))[0]
|
tinygrad/renderer/__init__.py
CHANGED
@@ -1,46 +1,13 @@
|
|
1
1
|
from __future__ import annotations
|
2
|
-
from typing import
|
3
|
-
import functools,
|
4
|
-
from enum import Enum, auto
|
2
|
+
from typing import Callable, cast, TYPE_CHECKING
|
3
|
+
import functools, itertools
|
5
4
|
from dataclasses import dataclass, field, replace
|
6
5
|
from tinygrad.helpers import to_function_name, dedup, prod
|
7
|
-
from tinygrad.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, GroupOp, PatternMatcher
|
8
|
-
from tinygrad.dtype import
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); SWAP = auto() # noqa: E702
|
13
|
-
def __lt__(self, x:OptOps): return self.value < x.value
|
14
|
-
|
15
|
-
@dataclass(frozen=True, order=True)
|
16
|
-
class Opt:
|
17
|
-
op: OptOps
|
18
|
-
axis: Optional[int] = None
|
19
|
-
arg: Optional[int | tuple] = None
|
20
|
-
def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, arg={self.arg})"
|
21
|
-
|
22
|
-
@dataclass(frozen=True)
|
23
|
-
class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N)
|
24
|
-
dims: tuple[int,int,int] # N, M, K
|
25
|
-
threads: int # number of threads that construct the warp
|
26
|
-
elements_per_thread: tuple[int, int, int] # elements per-thread to load/store from A/B/C
|
27
|
-
dtype_in: DType # dtype for A and B
|
28
|
-
dtype_out: DType # dtype for C and D
|
29
|
-
opts: tuple[str, ...] # ordered tuple of "ux" or "lx" specifing kernel opts to perform. "ux" upcasts dim x and "lx" localizes dim x
|
30
|
-
swizzle: tuple[Optional[tuple[tuple[int, ...], tuple[int, ...]]], Optional[tuple[tuple[int, ...], tuple[int, ...]]]] = (None, None)
|
31
|
-
def get_reduce_axes(self): return [(i, 2) for i in range(int(math.log2(self.dims[2])))]
|
32
|
-
def get_upcast_axes(self): return [opt for opt in self.opts if opt[0] == "u"]
|
33
|
-
def get_local_axes(self): return [opt for opt in self.opts if opt[0] == "l"]
|
34
|
-
def __str__(self): return "_".join(["WMMA"] + list(map(str, self.dims)) + [self.dtype_in.name, self.dtype_out.name])
|
35
|
-
def __post_init__(self):
|
36
|
-
local_axes, upcast_axes, reduce_axes = len(self.get_local_axes()), len(self.get_upcast_axes()), len(self.get_reduce_axes())
|
37
|
-
assert self.dims[0] * self.dims[1] == 2**(local_axes + upcast_axes), (
|
38
|
-
f"N({self.dims[0]}) x M({self.dims[1]}) != local({2**local_axes}) x upcast({2**upcast_axes}) with opts({self.opts})")
|
39
|
-
assert 2**local_axes == self.threads, f"{self.threads} threads construct the warp but found {2**local_axes} in {self.opts}"
|
40
|
-
assert 2**upcast_axes == self.elements_per_thread[2], (
|
41
|
-
f"{self.elements_per_thread[2]} elements from C are processed per thread but found {2**upcast_axes} in {self.opts}")
|
42
|
-
assert all(len(perm[0]) == local_axes and len(perm[1]) == reduce_axes + upcast_axes for perm in self.swizzle if perm), (
|
43
|
-
f"swizzle perm should be of len (({local_axes})({reduce_axes + upcast_axes}))")
|
6
|
+
from tinygrad.uop.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, GroupOp, PatternMatcher
|
7
|
+
from tinygrad.dtype import AddrSpace, PtrDType
|
8
|
+
if TYPE_CHECKING:
|
9
|
+
from tinygrad.codegen.opt.tc import TensorCore
|
10
|
+
from tinygrad.codegen.opt.kernel import Opt
|
44
11
|
|
45
12
|
@dataclass(frozen=True)
|
46
13
|
class Estimates:
|
@@ -61,19 +28,23 @@ class Estimates:
|
|
61
28
|
dont_count: set[UOp] = set()
|
62
29
|
if ignore_indexing:
|
63
30
|
for u in uops:
|
64
|
-
if u.op in {Ops.LOAD, Ops.STORE}:
|
65
|
-
dont_count = dont_count.union(u.src[0].toposort)
|
66
|
-
if len(u.src) > 2: dont_count = dont_count.union(u.src[2].toposort)
|
31
|
+
if u.op in {Ops.LOAD, Ops.STORE} and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
|
32
|
+
dont_count = dont_count.union(u.src[0].toposort())
|
33
|
+
if len(u.src) > 2: dont_count = dont_count.union(u.src[2].toposort())
|
67
34
|
elif u.op is Ops.IF:
|
68
|
-
dont_count = dont_count.union(u.src[0].toposort)
|
35
|
+
dont_count = dont_count.union(u.src[0].toposort())
|
69
36
|
for u in uops:
|
70
37
|
if u.op is Ops.RANGE:
|
71
38
|
mult_stack.append(mults)
|
72
|
-
mults *= (
|
39
|
+
mults *= cast(sint, u.src[0].ssimplify())
|
40
|
+
# SPECIAL are already counted in mults
|
41
|
+
mults = mults.substitute({x:x.const_like(0) for x in mults.toposort() if x.op is Ops.SPECIAL}) if isinstance(mults, UOp) else mults
|
73
42
|
elif u.op is Ops.ENDRANGE: mults = mult_stack.pop(-1)
|
74
43
|
elif u.op is Ops.SPECIAL: mults *= u.arg[1] # NOTE: we don't push to the mult_stack here, you can't end these
|
75
|
-
elif u.op is Ops.LOAD
|
76
|
-
|
44
|
+
elif u.op is Ops.LOAD and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
|
45
|
+
lds += u.dtype.itemsize * mults
|
46
|
+
elif u.op is Ops.STORE and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
|
47
|
+
lds += u.src[1].dtype.itemsize * mults
|
77
48
|
elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
|
78
49
|
elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
|
79
50
|
return Estimates(flops, lds, lds) # TODO: properly track memory, lds is always a high estimate
|
@@ -84,13 +55,11 @@ class ProgramSpec:
|
|
84
55
|
src:str
|
85
56
|
device:str
|
86
57
|
ast:UOp # save the base ast (this is method cache key)
|
87
|
-
uops:
|
88
|
-
applied_opts:Optional[list[Opt]]=None
|
89
|
-
mem_estimate:sint=0 # TODO: get this from the load/store uops once min/max are good
|
58
|
+
uops:list[UOp]|None=None
|
90
59
|
|
91
60
|
# filled in from uops (if we have uops)
|
92
|
-
global_size:
|
93
|
-
local_size:
|
61
|
+
global_size:list[int]|None=None
|
62
|
+
local_size:list[int]|None=None
|
94
63
|
vars:list[Variable]=field(default_factory=list)
|
95
64
|
globals:list[int]=field(default_factory=list)
|
96
65
|
outs:list[int]=field(default_factory=list)
|
@@ -103,19 +72,26 @@ class ProgramSpec:
|
|
103
72
|
for u in self.uops:
|
104
73
|
if u.op is Ops.DEFINE_VAR: self.vars.append(u)
|
105
74
|
if u.op is Ops.DEFINE_GLOBAL: self.globals.append(u.arg)
|
106
|
-
if u.op is Ops.STORE: self.outs.extend([x.arg for x in u.src[0].toposort if x.op is Ops.DEFINE_GLOBAL])
|
107
|
-
if u.op is Ops.LOAD: self.ins.extend([x.arg for x in u.src[0].toposort if x.op is Ops.DEFINE_GLOBAL])
|
75
|
+
if u.op is Ops.STORE: self.outs.extend([x.arg for x in u.src[0].toposort() if x.op is Ops.DEFINE_GLOBAL])
|
76
|
+
if u.op is Ops.LOAD: self.ins.extend([x.arg for x in u.src[0].toposort() if x.op is Ops.DEFINE_GLOBAL])
|
108
77
|
if u.op is Ops.SPECIAL:
|
109
78
|
# NOTE: you have to set local_size and global_size to the base [1,1,1] outside this
|
110
79
|
if u.arg[0][0] == 'i': self.local_size = None
|
111
80
|
special_size = self.local_size if u.arg[0][0] == 'l' else self.global_size
|
112
|
-
|
113
|
-
special_size[int(u.arg[0][-1])] = u.arg[1]
|
81
|
+
if special_size is not None: special_size[int(u.arg[0][-1])] = u.arg[1]
|
114
82
|
self.vars = sorted(self.vars, key=lambda v: v.arg)
|
115
83
|
self.outs = sorted(dedup(self.outs))
|
116
84
|
self.ins = sorted(dedup(self.ins))
|
117
85
|
self._ran_post_init = True
|
118
86
|
|
87
|
+
@functools.cached_property
|
88
|
+
def mem_estimate(self) -> sint:
|
89
|
+
# group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
|
90
|
+
# TODO: these max and min don't work on symbolic, and results are very wrong.
|
91
|
+
return sum(max(x.src[0].dtype.nbytes() for x in group)
|
92
|
+
for _, group in itertools.groupby([x for x in self.ast.toposort() if x.op in {Ops.LOAD, Ops.STORE} and x.src[0].base.op is Ops.DEFINE_GLOBAL],
|
93
|
+
key=lambda x: (x.op, x.src[0].base.arg)))
|
94
|
+
|
119
95
|
@functools.cached_property
|
120
96
|
def estimates(self) -> Estimates:
|
121
97
|
return replace(Estimates() if self.uops is None else Estimates.from_uops(self.uops, ignore_indexing=True), mem=self.mem_estimate)
|
@@ -123,6 +99,10 @@ class ProgramSpec:
|
|
123
99
|
@functools.cached_property
|
124
100
|
def function_name(self) -> str: return to_function_name(self.name)
|
125
101
|
|
102
|
+
@property
|
103
|
+
def applied_opts(self) -> tuple[Opt, ...]|None: return self.uops[-1].arg.applied_opts if \
|
104
|
+
self.uops is not None and self.uops[-1].op is Ops.SINK and self.uops[-1].arg is not None else None
|
105
|
+
|
126
106
|
def launch_dims(self, var_vals:dict[Variable, int]):
|
127
107
|
global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else None
|
128
108
|
local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else None
|
@@ -136,12 +116,12 @@ class Renderer:
|
|
136
116
|
has_local: bool = True
|
137
117
|
has_shared: bool = True
|
138
118
|
# NOTE: these two should be in (x,y,z) order to match the max_sizes argument in get_grouped_dims
|
139
|
-
global_max:
|
140
|
-
local_max:
|
119
|
+
global_max: tuple[int, ...]|None = (0x8FFFFFFF,) * (3) # TODO: Ops.SPECIAL int32 indexes right now
|
120
|
+
local_max: tuple[int, ...]|None = (0x8FFFFFFF,) * (3) # TODO: Ops.SPECIAL int32 indexes right now
|
141
121
|
shared_max: int = 32768
|
142
122
|
tensor_cores: list[TensorCore] = []
|
143
|
-
pre_matcher:
|
144
|
-
extra_matcher:
|
123
|
+
pre_matcher: PatternMatcher|None = None
|
124
|
+
extra_matcher: PatternMatcher|None = None
|
145
125
|
code_for_op: dict[Ops, Callable] = {}
|
146
126
|
|
147
127
|
def __reduce__(self): return self.__class__, ()
|