tinygrad 0.10.0__py3-none-any.whl → 0.10.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/devectorizer.py +247 -0
- tinygrad/codegen/expander.py +121 -0
- tinygrad/codegen/kernel.py +141 -201
- tinygrad/codegen/linearize.py +223 -84
- tinygrad/codegen/lowerer.py +60 -42
- tinygrad/codegen/symbolic.py +476 -0
- tinygrad/codegen/transcendental.py +22 -13
- tinygrad/device.py +187 -47
- tinygrad/dtype.py +39 -28
- tinygrad/engine/jit.py +83 -65
- tinygrad/engine/memory.py +4 -5
- tinygrad/engine/multi.py +161 -0
- tinygrad/engine/realize.py +62 -108
- tinygrad/engine/schedule.py +396 -357
- tinygrad/engine/search.py +55 -66
- tinygrad/gradient.py +73 -0
- tinygrad/helpers.py +81 -59
- tinygrad/nn/__init__.py +30 -32
- tinygrad/nn/datasets.py +1 -2
- tinygrad/nn/optim.py +22 -26
- tinygrad/nn/state.py +91 -66
- tinygrad/ops.py +492 -641
- tinygrad/renderer/__init__.py +95 -36
- tinygrad/renderer/cstyle.py +99 -92
- tinygrad/renderer/llvmir.py +83 -34
- tinygrad/renderer/ptx.py +83 -99
- tinygrad/renderer/wgsl.py +95 -0
- tinygrad/runtime/autogen/amd_gpu.py +39507 -12
- tinygrad/runtime/autogen/comgr.py +2 -0
- tinygrad/runtime/autogen/kfd.py +4 -3
- tinygrad/runtime/autogen/kgsl.py +1 -1
- tinygrad/runtime/autogen/libc.py +404 -71
- tinygrad/runtime/autogen/llvm.py +11379 -0
- tinygrad/runtime/autogen/pci.py +1333 -0
- tinygrad/runtime/autogen/vfio.py +891 -0
- tinygrad/runtime/autogen/webgpu.py +6985 -0
- tinygrad/runtime/graph/cuda.py +8 -9
- tinygrad/runtime/graph/hcq.py +84 -79
- tinygrad/runtime/graph/metal.py +40 -43
- tinygrad/runtime/ops_amd.py +498 -334
- tinygrad/runtime/ops_cloud.py +34 -34
- tinygrad/runtime/ops_cpu.py +24 -0
- tinygrad/runtime/ops_cuda.py +30 -27
- tinygrad/runtime/ops_disk.py +62 -63
- tinygrad/runtime/ops_dsp.py +159 -42
- tinygrad/runtime/ops_gpu.py +30 -30
- tinygrad/runtime/ops_hip.py +29 -31
- tinygrad/runtime/ops_llvm.py +48 -41
- tinygrad/runtime/ops_metal.py +149 -113
- tinygrad/runtime/ops_npy.py +2 -2
- tinygrad/runtime/ops_nv.py +238 -273
- tinygrad/runtime/ops_python.py +55 -50
- tinygrad/runtime/ops_qcom.py +129 -157
- tinygrad/runtime/ops_webgpu.py +225 -0
- tinygrad/runtime/support/allocator.py +94 -0
- tinygrad/runtime/support/am/__init__.py +0 -0
- tinygrad/runtime/support/am/amdev.py +396 -0
- tinygrad/runtime/support/am/ip.py +463 -0
- tinygrad/runtime/support/compiler_cuda.py +4 -2
- tinygrad/runtime/support/elf.py +28 -4
- tinygrad/runtime/support/hcq.py +256 -324
- tinygrad/runtime/support/llvm.py +26 -0
- tinygrad/shape/shapetracker.py +85 -53
- tinygrad/shape/view.py +104 -140
- tinygrad/spec.py +155 -0
- tinygrad/tensor.py +835 -527
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
- tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
- tinygrad/viz/index.html +544 -0
- tinygrad/viz/perfetto.html +178 -0
- tinygrad/viz/serve.py +205 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
- tinygrad-0.10.2.dist-info/RECORD +99 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/uopgraph.py +0 -506
- tinygrad/engine/lazy.py +0 -228
- tinygrad/function.py +0 -212
- tinygrad/multi.py +0 -177
- tinygrad/runtime/graph/clang.py +0 -39
- tinygrad/runtime/ops_clang.py +0 -35
- tinygrad-0.10.0.dist-info/RECORD +0 -77
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/nn/__init__.py
CHANGED
@@ -1,7 +1,7 @@
 from __future__ import annotations
 import math
-from
-from tinygrad.
+from tinygrad.tensor import Tensor
+from tinygrad.dtype import dtypes
 from tinygrad.device import is_dtype_supported
 from tinygrad.helpers import prod, make_tuple, flatten
 from tinygrad.nn import optim, state, datasets # noqa: F401
@@ -34,14 +34,14 @@ class BatchNorm:
   def __init__(self, sz:int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
     self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum

-    self.weight:
-    self.bias:
+    self.weight: Tensor|None = Tensor.ones(sz) if affine else None
+    self.bias: Tensor|None = Tensor.zeros(sz) if affine else None

     self.num_batches_tracked = Tensor.zeros(1, dtype='long' if is_dtype_supported(dtypes.long) else 'int', requires_grad=False)
     if track_running_stats: self.running_mean, self.running_var = Tensor.zeros(sz, requires_grad=False), Tensor.ones(sz, requires_grad=False)

-  def calc_stats(self, x:Tensor) ->
-    shape_mask:
+  def calc_stats(self, x:Tensor) -> tuple[Tensor, Tensor]:
+    shape_mask: list[int] = [1, -1, *([1]*(x.ndim-2))]
     if self.track_running_stats and not Tensor.training: return self.running_mean, self.running_var.reshape(shape=shape_mask).expand(x.shape)
     # This requires two full memory accesses to x
     # https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
@@ -56,12 +56,12 @@ class BatchNorm:
     # NOTE: wow, this is done all throughout training in most PyTorch models
     if self.track_running_stats and Tensor.training:
       self.running_mean.assign((1-self.momentum) * self.running_mean + self.momentum * batch_mean.detach())
-      self.running_var.assign((1-self.momentum) * self.running_var + self.momentum *
+      self.running_var.assign((1-self.momentum) * self.running_var + self.momentum * x.numel()/(x.numel()-x.shape[1]) * batch_var.detach())
       self.num_batches_tracked += 1
     return x.batchnorm(self.weight, self.bias, batch_mean, batch_var.add(self.eps).rsqrt())
 BatchNorm2d = BatchNorm3d = BatchNorm

-def Conv1d(in_channels:int, out_channels:int, kernel_size:int, stride=1, padding:
+def Conv1d(in_channels:int, out_channels:int, kernel_size:int, stride=1, padding:int|str=0, dilation=1, groups=1, bias=True) -> Conv2d:
   """
   Applies a 1D convolution over an input signal composed of several input planes.

@@ -95,7 +95,7 @@ class Conv2d:
   print(t.numpy())
   ```
   """
-  def __init__(self, in_channels:int, out_channels:int, kernel_size:
+  def __init__(self, in_channels:int, out_channels:int, kernel_size:int|tuple[int, ...], stride=1, padding:int|tuple[int, ...]|str=0,
                dilation=1, groups=1, bias=True):
     self.kernel_size = make_tuple(kernel_size, 2)
     if isinstance(padding, str):
@@ -106,10 +106,9 @@ class Conv2d:
     self.stride, self.dilation, self.groups, self.padding = stride, dilation, groups, padding
     scale = 1 / math.sqrt(in_channels * prod(self.kernel_size))
     self.weight = Tensor.uniform(out_channels, in_channels//groups, *self.kernel_size, low=-scale, high=scale)
-    self.bias:
+    self.bias: Tensor|None = Tensor.uniform(out_channels, low=-scale, high=scale) if bias else None

-  def __call__(self, x:Tensor) -> Tensor:
-    return x.conv2d(self.weight, self.bias, padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
+  def __call__(self, x:Tensor) -> Tensor: return x.conv2d(self.weight, self.bias, self.groups, self.stride, self.dilation, self.padding)

 def ConvTranspose1d(in_channels:int, out_channels:int, kernel_size:int, stride=1, padding=0, output_padding=0, dilation=1,
                     groups=1, bias=True) -> ConvTranspose2d:
@@ -146,7 +145,7 @@ class ConvTranspose2d(Conv2d):
   print(t.numpy())
   ```
   """
-  def __init__(self, in_channels:int, out_channels:int, kernel_size:
+  def __init__(self, in_channels:int, out_channels:int, kernel_size:int|tuple[int, ...], stride=1, padding=0, output_padding=0,
                dilation=1, groups=1, bias=True):
     super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
     scale = 1 / math.sqrt(in_channels * prod(self.kernel_size))
@@ -154,8 +153,7 @@ class ConvTranspose2d(Conv2d):
     self.output_padding = output_padding

   def __call__(self, x:Tensor) -> Tensor:
-    return x.conv_transpose2d(self.weight, self.bias,
-                              dilation=self.dilation, groups=self.groups)
+    return x.conv_transpose2d(self.weight, self.bias, self.groups, self.stride, self.dilation, self.padding, self.output_padding)

 class Linear:
   """
@@ -178,8 +176,7 @@ class Linear:
     self.weight = Tensor.uniform(out_features, in_features, low=-bound, high=bound)
     self.bias = Tensor.uniform(out_features, low=-bound, high=bound) if bias else None

-  def __call__(self, x:Tensor) -> Tensor:
-    return x.linear(self.weight.transpose(), self.bias)
+  def __call__(self, x:Tensor) -> Tensor: return x.linear(self.weight.transpose(), self.bias)

 class GroupNorm:
   """
@@ -200,8 +197,8 @@ class GroupNorm:
   """
   def __init__(self, num_groups:int, num_channels:int, eps=1e-5, affine=True):
     self.num_groups, self.num_channels, self.eps = num_groups, num_channels, eps
-    self.weight:
-    self.bias:
+    self.weight: Tensor|None = Tensor.ones(num_channels) if affine else None
+    self.bias: Tensor|None = Tensor.zeros(num_channels) if affine else None

   def __call__(self, x:Tensor) -> Tensor:
     # reshape for layernorm to work as group norm
@@ -210,7 +207,7 @@ class GroupNorm:

     if self.weight is None or self.bias is None: return x
     # elementwise_affine on channels
-    return x * self.weight.reshape(1, -1, *[1] * (
+    return x * self.weight.reshape(1, -1, *[1] * (x.ndim-2)) + self.bias.reshape(1, -1, *[1] * (x.ndim-2))

 class InstanceNorm:
   """
@@ -231,13 +228,13 @@ class InstanceNorm:
   """
   def __init__(self, num_features:int, eps=1e-5, affine=True):
     self.num_features, self.eps = num_features, eps
-    self.weight:
-    self.bias:
+    self.weight: Tensor|None = Tensor.ones(num_features) if affine else None
+    self.bias: Tensor|None = Tensor.zeros(num_features) if affine else None

   def __call__(self, x:Tensor) -> Tensor:
     x = x.reshape(x.shape[0], self.num_features, -1).layernorm(eps=self.eps).reshape(x.shape)
     if self.weight is None or self.bias is None: return x
-    return x * self.weight.reshape(1, -1, *[1] * (
+    return x * self.weight.reshape(1, -1, *[1] * (x.ndim-2)) + self.bias.reshape(1, -1, *[1] * (x.ndim-2))

 class LayerNorm:
   """
@@ -256,10 +253,11 @@ class LayerNorm:
   print(t.mean().item(), t.std().item())
   ```
   """
-  def __init__(self, normalized_shape:
-    self.normalized_shape:
+  def __init__(self, normalized_shape:int|tuple[int, ...], eps=1e-5, elementwise_affine=True):
+    self.normalized_shape: tuple[int, ...] = make_tuple(normalized_shape, 1)
     self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(self.normalized_shape))), eps, elementwise_affine
-    self.weight
+    self.weight: Tensor|None = Tensor.ones(*self.normalized_shape) if elementwise_affine else None
+    self.bias: Tensor|None = Tensor.zeros(*self.normalized_shape) if elementwise_affine else None

   def __call__(self, x:Tensor) -> Tensor:
     assert self.normalized_shape == x.shape[-len(self.normalized_shape):], f"last dimensions of {x.shape} must match {self.normalized_shape}"
@@ -322,10 +320,9 @@ class Embedding:
     self.vocab_sz, self.embed_sz, self.weight = vocab_size, embed_size, Tensor.glorot_uniform(vocab_size, embed_size)

   def __call__(self, idx:Tensor) -> Tensor:
-    if
-
-
-    arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1,)).expand(big_shp), self.weight.reshape(weight_shp).expand(big_shp)
+    if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).unsqueeze(-1)
+    big_shp = idx.shape+(self.vocab_sz, self.embed_sz)
+    arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1)).expand(big_shp), self.weight.expand(big_shp)
     return (arange == idx).mul(vals).sum(-2, acc_dtype=vals.dtype)

 class LSTMCell:
@@ -341,9 +338,10 @@ class LSTMCell:
     stdv = 1.0 / math.sqrt(hidden_size)
     self.weight_ih = Tensor.uniform(hidden_size*4, input_size, low=-stdv, high=stdv)
     self.weight_hh = Tensor.uniform(hidden_size*4, hidden_size, low=-stdv, high=stdv)
-    self.bias_ih
+    self.bias_ih: Tensor|None = Tensor.zeros(hidden_size*4) if bias else None
+    self.bias_hh: Tensor|None = Tensor.zeros(hidden_size*4) if bias else None

-  def __call__(self, x:Tensor, hc:
+  def __call__(self, x:Tensor, hc:tuple[Tensor, Tensor]|None=None) -> tuple[Tensor, Tensor]:
     if hc is None: hc = (Tensor.zeros(x.size(0), self.weight_hh.size(1), dtype=x.dtype, device=x.device),)*2
     gates = x.linear(self.weight_ih.T, self.bias_ih) + hc[0].linear(self.weight_hh.T, self.bias_hh)
     i, f, g, o = gates.chunk(4, dim=1)
tinygrad/nn/datasets.py
CHANGED
@@ -1,5 +1,4 @@
 from tinygrad.tensor import Tensor
-from tinygrad.helpers import fetch
 from tinygrad.nn.state import tar_extract

 def mnist(device=None, fashion=False):
@@ -9,7 +8,7 @@ def mnist(device=None, fashion=False):
          _mnist("t10k-images-idx3-ubyte.gz")[0x10:].reshape(-1,1,28,28).to(device), _mnist("t10k-labels-idx1-ubyte.gz")[8:].to(device)

 def cifar(device=None):
-  tt = tar_extract(
+  tt = tar_extract(Tensor.from_url('https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz', gunzip=True))
   train = Tensor.cat(*[tt[f"cifar-10-batches-bin/data_batch_{i}.bin"].reshape(-1, 3073).to(device) for i in range(1,6)])
   test = tt["cifar-10-batches-bin/test_batch.bin"].reshape(-1, 3073).to(device)
   return train[:, 1:].reshape(-1,3,32,32), train[:, 0], test[:, 1:].reshape(-1,3,32,32), test[:, 0]
tinygrad/nn/optim.py
CHANGED
@@ -1,6 +1,5 @@
 # sorted in order of increasing complexity
-from
-from tinygrad.helpers import dedup, flatten, getenv
+from tinygrad.helpers import dedup, flatten, getenv, unwrap
 from tinygrad.tensor import Tensor
 from tinygrad.dtype import dtypes, least_upper_dtype

@@ -8,15 +7,15 @@ class Optimizer:
   """
   Base class for all optimizers.
   """
-  def __init__(self, params:
+  def __init__(self, params: list[Tensor], lr: float):
     # if it's None, but being put into an optimizer, set it to True
     for x in params:
       if x.requires_grad is None: x.requires_grad = True

-    self.params:
+    self.params: list[Tensor] = dedup([x for x in params if x.requires_grad])
     assert len(self.params) != 0, "optimizer must have at least one param"
     self.device = self.params[0].device
-    self.buffers:
+    self.buffers: list[Tensor] = dedup([x for x in params if not x.requires_grad]) # buffers are still realized
     # store lr in at least float32 precision
     self.lr = Tensor(lr if getenv("CONST_LR") else [lr], requires_grad=False, device=self.device,
                      dtype=least_upper_dtype(dtypes.default_float, dtypes.float32))
@@ -32,15 +31,17 @@ class Optimizer:
     Performs a single optimization step.
     """
     Tensor.realize(*self.schedule_step())
-
+
+  def schedule_step(self) -> list[Tensor]:
     """
     Returns the tensors that need to be realized to perform a single optimization step.
     """
-
+    if not Tensor.training: raise RuntimeError(
       f"""Tensor.training={Tensor.training}, Tensor.training must be enabled to use the optimizer.
             - help: Consider setting Tensor.training=True before calling Optimizer.step().""")
-    return self.
-
+    return self.schedule_step_with_grads([unwrap(t.grad) for t in self.params])+self.params+self.buffers
+
+  def schedule_step_with_grads(self, grads:list[Tensor]) -> list[Tensor]: raise NotImplementedError

 class OptimizerGroup(Optimizer):
   """
@@ -51,10 +52,10 @@ class OptimizerGroup(Optimizer):
     self.params, self.buffers = flatten([o.params for o in self.optimizers]), flatten([o.buffers for o in self.optimizers])
   def __getitem__(self, i): return self.optimizers[i]
   def zero_grad(self): [o.zero_grad() for o in self.optimizers]
-  def
+  def schedule_step(self) -> list[Tensor]: return [x for o in self.optimizers for x in o.schedule_step()]

 # LARS is essentially just trust ratio to SGD so if we just set the trust coeff 0.0 its just standard SGD.
-def SGD(params:
+def SGD(params: list[Tensor], lr=0.001, momentum=0.0, weight_decay=0.0, nesterov=False, classic=False):
   """
   Stochastic Gradient Descent (SGD) optimizer with optional momentum and weight decay.

@@ -71,17 +72,13 @@ class LARS(Optimizer):
   - Described: https://paperswithcode.com/method/lars
   - Paper: https://arxiv.org/abs/1708.03888v3
   """
-  def __init__(self, params:
+  def __init__(self, params:list[Tensor], lr=0.001, momentum=0.9, weight_decay=1e-4, nesterov=False, classic=True, tcoef=0.001):
     super().__init__(params, lr)
     self.momentum, self.wd, self.nesterov, self.classic, self.tcoef = momentum, weight_decay, nesterov, classic, tcoef
     self.b = [Tensor.zeros(*t.shape, dtype=t.dtype, device=t.device, requires_grad=False) for t in self.params] if self.momentum else []

-  def
-    for i, t in enumerate(self.params):
-      assert t.grad is not None
-      # contiguous is needed since the grads can allegedly form a "diamond"
-      # TODO: fix this in lazy.py
-      g = t.grad.contiguous()
+  def schedule_step_with_grads(self, grads:list[Tensor]) -> list[Tensor]:
+    for i, (t, g) in enumerate(zip(self.params, grads)):
       if self.tcoef != 0:
         r1 = t.detach().square().sum().sqrt()
         r2 = g.square().sum().sqrt()
@@ -99,7 +96,7 @@ class LARS(Optimizer):
     return self.b

 # LAMB is essentially just the trust ratio part of LARS applied to Adam/W so if we just set the trust ratio to 1.0 its just Adam/W.
-def AdamW(params:
+def AdamW(params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.01):
   """
   AdamW optimizer with optional weight decay.

@@ -107,7 +104,7 @@ def AdamW(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, weight_dec
   - Paper: https://arxiv.org/abs/1711.05101v3
   """
   return LAMB(params, lr, b1, b2, eps, weight_decay, adam=True)
-def Adam(params:
+def Adam(params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
   """
   Adam optimizer.

@@ -123,20 +120,19 @@ class LAMB(Optimizer):
   - Described: https://paperswithcode.com/method/lamb
   - Paper: https://arxiv.org/abs/1904.00962
   """
-  def __init__(self, params:
+  def __init__(self, params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, adam=False):
     super().__init__(params, lr)
     self.b1, self.b2, self.eps, self.wd, self.adam = b1, b2, eps, weight_decay, adam
     self.b1_t, self.b2_t = (Tensor.ones((1,), dtype=dtypes.float32, device=self.device, requires_grad=False).contiguous() for _ in [b1, b2])
     self.m = [Tensor.zeros(*t.shape, dtype=dtypes.float32, device=t.device, requires_grad=False).contiguous() for t in self.params]
     self.v = [Tensor.zeros(*t.shape, dtype=dtypes.float32, device=t.device, requires_grad=False).contiguous() for t in self.params]

-  def
+  def schedule_step_with_grads(self, grads:list[Tensor]) -> list[Tensor]:
     self.b1_t *= self.b1
     self.b2_t *= self.b2
-    for i, t in enumerate(self.params):
-
-      self.
-      self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad))
+    for i, (t, g) in enumerate(zip(self.params, grads)):
+      self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * g)
+      self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (g * g))
       m_hat = self.m[i] / (1.0 - self.b1_t)
       v_hat = self.v[i] / (1.0 - self.b2_t)
       up = (m_hat / (v_hat.sqrt() + self.eps)) + self.wd * t.detach()
tinygrad/nn/state.py
CHANGED
@@ -1,24 +1,54 @@
-import
-from
+import json, pathlib, zipfile, pickle, tarfile, struct, functools, io
+from collections import OrderedDict
+from typing import Union, Optional, Any, Callable, BinaryIO, Iterable
 from tinygrad.tensor import Tensor
 from tinygrad.dtype import dtypes
-from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm
+from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm, round_up, T
 from tinygrad.shape.view import strides_for_shape
-
+
+class TensorIO(io.RawIOBase, BinaryIO):
+  def __init__(self, t: Tensor):
+    if t.ndim != 1 or t.dtype != dtypes.uint8: raise ValueError("Tensor must be 1d and of dtype uint8!")
+    self._position, self._tensor = 0, t
+
+  def readable(self) -> bool: return True
+  def read(self, size: int = -1) -> bytes:
+    if (buf:=super().read(size)) is None: raise ValueError("io.RawIOBase.read returned None") # only happens if readinto returns None (never)
+    return buf
+  def readinto(self, buffer: Any) -> int:
+    data = self._tensor[self._position:self._position+len(buffer)].data()
+    buffer[:len(data)] = data
+    self._position += len(data)
+    return len(data)
+
+  def seekable(self) -> bool: return True
+  def seek(self, offset: int, whence: int = 0) -> int:
+    self._position = min(len(self._tensor), max(0, [offset, self._position+offset, len(self._tensor)+offset][whence]))
+    return self._position
+
+  # required to correctly implement BinaryIO
+  def __enter__(self): return self
+  def write(self, s: Any): raise io.UnsupportedOperation("TensorIO.write not supported")
+  def writelines(self, lines: Iterable[Any]): raise io.UnsupportedOperation("TensorIO.writelines not supported")

 safe_dtypes = {"BOOL":dtypes.bool, "I8":dtypes.int8, "U8":dtypes.uint8, "I16":dtypes.int16, "U16":dtypes.uint16, "I32":dtypes.int, "U32":dtypes.uint,
                "I64":dtypes.int64, "U64":dtypes.uint64, "F16":dtypes.float16, "BF16":dtypes.bfloat16, "F32":dtypes.float32, "F64":dtypes.float64}
 inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}

-def
+def accept_filename(func: Callable[[Tensor], T]) -> Callable[[Union[Tensor, str, pathlib.Path]], T]:
+  @functools.wraps(func)
+  def wrapper(fn: Union[Tensor, str, pathlib.Path]) -> T: return func(Tensor(pathlib.Path(fn)) if not isinstance(fn, Tensor) else fn)
+  return wrapper
+
+@accept_filename
+def safe_load_metadata(t:Tensor) -> tuple[Tensor, int, dict[str, Any]]:
   """
   Loads a .safetensor file from disk, returning the data, metadata length, and metadata.
   """
-
-
-  return t, json_len, json.loads(t[8:8+json_len].data().tobytes())
+  data_start = int.from_bytes(t[0:8].data(), "little") + 8
+  return t, data_start, json.loads(t[8:data_start].data().tobytes())

-def safe_load(fn:Union[Tensor,str]) ->
+def safe_load(fn:Union[Tensor, str, pathlib.Path]) -> dict[str, Tensor]:
   """
   Loads a .safetensor file from disk, returning the state_dict.

@@ -26,16 +56,12 @@ def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
   state_dict = nn.state.safe_load("test.safetensor")
   ```
   """
-  t,
-
-
-
-
-
-    ret[k] = t[8+json_len+v['data_offsets'][0]:8+json_len+v['data_offsets'][0]+sz].bitcast(dtype).reshape(v['shape'])
-  return ret
-
-def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any]]=None):
+  t, data_start, metadata = safe_load_metadata(fn)
+  data = t[data_start:]
+  return { k: data[v['data_offsets'][0]:v['data_offsets'][1]].bitcast(safe_dtypes[v['dtype']]).reshape(v['shape'])
+           for k, v in metadata.items() if k != "__metadata__" }
+
+def safe_save(tensors:dict[str, Tensor], fn:str, metadata:Optional[dict[str, Any]]=None):
   """
   Saves a state_dict to disk in a .safetensor file with optional metadata.

@@ -50,7 +76,7 @@ def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any
     headers[k] = {'dtype': inverse_safe_dtypes[v.dtype], 'shape': list(v.shape), 'data_offsets':[offset, offset+v.nbytes()]}
     offset += v.nbytes()
   j = json.dumps(headers, separators=(',', ':'))
-  j += "\x20"*((8-len(j)
+  j += "\x20"*(round_up(len(j),8)-len(j))
   pathlib.Path(fn).unlink(missing_ok=True)
   t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}")
   t[0:8].bitcast(dtypes.int64).assign([len(j)])
@@ -59,8 +85,7 @@ def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any

 # state dict

-
-def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> Dict[str, Tensor]:
+def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> dict[str, Tensor]:
   """
   Returns a state_dict of the object, with optional prefix.

@@ -84,7 +109,8 @@ def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> Dict[str, Tensor]:
   elif isinstance(obj, dict):
     for k,v in obj.items(): state_dict.update(get_state_dict(v, f"{prefix}{str(k)}.", tensor_type))
   return state_dict
-
+
+def get_parameters(obj) -> list[Tensor]:
   """
   ```python exec="true" source="above" session="tensor" result="python"
   class Net:
@@ -98,7 +124,7 @@ def get_parameters(obj) -> List[Tensor]:
   """
   return list(get_state_dict(obj).values())

-def load_state_dict(model, state_dict:
+def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=True, consume=False) -> None:
   """
   Loads a state_dict into a model.

@@ -114,7 +140,7 @@ def load_state_dict(model, state_dict:Dict[str, Tensor], strict=True, verbose=Tr
   ```
   """
   start_mem_used = GlobalCounters.mem_used
-  with Timing("loaded weights in ", lambda et_ns: f", {(GlobalCounters.mem_used-start_mem_used)/1e9:.2f} GB loaded at {
+  with Timing("loaded weights in ", lambda et_ns: f", {(B:=(GlobalCounters.mem_used-start_mem_used))/1e9:.2f} GB loaded at {B/et_ns:.2f} GB/s"):
     model_state_dict = get_state_dict(model)
     if DEBUG >= 1 and len(state_dict) > len(model_state_dict):
       print("WARNING: unused weights in state_dict", sorted(list(state_dict.keys() - model_state_dict.keys())))
@@ -123,27 +149,30 @@ def load_state_dict(model, state_dict:Dict[str, Tensor], strict=True, verbose=Tr
       if k not in state_dict and not strict:
         if DEBUG >= 1: print(f"WARNING: not loading {k}")
         continue
-      if
-
-
+      if v.shape != state_dict[k].shape:
+        raise ValueError(f'Shape mismatch in layer `{k}`: Expected shape {v.shape}, but found {state_dict[k].shape} in state dict.')
+      if isinstance(v.device, tuple):
+        if isinstance(state_dict[k].device, tuple): v.replace(state_dict[k]).realize()
+        else: v.replace(state_dict[k].shard(v.device, v.lazydata.axis)).realize()
       else: v.replace(state_dict[k].to(v.device)).realize()
       if consume: del state_dict[k]

-
+@accept_filename
+def tar_extract(t: Tensor) -> dict[str, Tensor]:
   """
   Extracts files from a tar archive and returns them as dictionary of names (keys) and tensors (values).

   ```python
-  tensors = nn.state.tar_extract("archive.tar")
+  tensors = nn.state.tar_extract(Tensor(pathlib.Path("archive.tar")))
   ```
   """
-
-  with tarfile.open(fn, "r") as tar:
+  with tarfile.open(fileobj=TensorIO(t), mode="r") as tar:
     return {member.name:t[member.offset_data:member.offset_data+member.size] for member in tar if member.type == tarfile.REGTYPE}

 # torch support!

-
+@accept_filename
+def torch_load(t:Tensor) -> dict[str, Tensor]:
   """
   Loads a torch .pth file from disk.

@@ -151,10 +180,8 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
   state_dict = nn.state.torch_load("test.pth")
   ```
   """
-
-
-  offsets: Dict[Union[str, int], int] = {}
-  lens: Dict[Union[str, int], int] = {}
+  offsets: dict[Union[str, int], int] = {}
+  lens: dict[Union[str, int], int] = {}
   def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad=None, backward_hooks=None, metadata=None):
     #print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
     lens[storage[2]] = storage[4] * storage[1].itemsize
@@ -168,8 +195,8 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
     if tuple(permute_indexes) != tuple(range(len(permute_indexes))):
       intermediate_shape = tuple([shape_strides[x][0] for x in argsort(permute_indexes)])
       assert tuple([shape_strides[i][1] for i in argsort(permute_indexes)]) == strides_for_shape(intermediate_shape), "nonpermutable strides"
-      if DEBUG >= 3: print(f"WARNING: this torch load is slow.
-      assert storage[1] != dtypes.bfloat16, "can't
+      if DEBUG >= 3: print(f"WARNING: this torch load is slow. to permute {intermediate_shape} with {permute_indexes}")
+      assert storage[1] != dtypes.bfloat16, "can't permute BF16"
       # TODO: find a nice way to support all shapetracker on disktensors
       ret = ret.to(None).reshape(intermediate_shape).permute(permute_indexes)

@@ -178,7 +205,7 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
   class Parameter:
     def __setstate__(self, state): self.tensor = state[0]

-  deserialized_objects:
+  deserialized_objects: dict[str, Any] = {}
   intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16,
                "IntStorage": dtypes.int32, "BoolStorage": dtypes.bool,
                "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2, "FloatTensor": None, "Parameter": Parameter}
@@ -193,8 +220,11 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
       return intercept[name] if module_root == "torch" else super().find_class(module, name)
     def persistent_load(self, pid): return deserialized_objects.get(pid, pid)

-
-
+  fobj = io.BufferedReader(TensorIO(t))
+  def passthrough_reset(v: bool): return fobj.seek(0, 0) or v
+
+  if passthrough_reset(zipfile.is_zipfile(fobj)): # NOTE: passthrough_reset required to support python < 3.14
+    myzip = zipfile.ZipFile(fobj, 'r')
     base_name = myzip.namelist()[0].split('/', 1)[0]
     for n in myzip.namelist():
       if n.startswith(f'{base_name}/data/'):
@@ -202,8 +232,8 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
          offsets[n.split("/")[-1]] = myfile._orig_compress_start # type: ignore
     with myzip.open(f'{base_name}/data.pkl') as myfile:
       return TorchPickle(myfile).load()
-  elif tarfile.is_tarfile(
-    with tarfile.open(
+  elif passthrough_reset(tarfile.is_tarfile(fobj)): # NOTE: passthrough_reset required to support python < 3.11
+    with tarfile.open(fileobj=fobj, mode="r") as tar:
       storages_offset = tar.getmember('storages').offset_data
       f = unwrap(tar.extractfile('storages'))
       for i in range(TorchPickle(f).load()): # num_storages
@@ -218,14 +248,13 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
        deserialized_objects[str(key)] = _rebuild_tensor_v2((None, storage_type, storage_id, None, -1), storage_offset, size, stride)
      return {k:v.tensor if isinstance(v, Parameter) else v for k,v in TorchPickle(unwrap(tar.extractfile('pickle'))).load().items()}
   else:
-
-
-
-
-
-
-
-    return TorchPickle(f).load()
+    pkl = TorchPickle(fobj)
+    _, _, _, rwd, _, ids, base_offset = pkl.load(), pkl.load(), pkl.load(), fobj.tell(), pkl.load(), pkl.load(), fobj.tell()
+    for i in ids:
+      offsets[i] = base_offset + 8
+      base_offset += 8 + lens[i]
+    fobj.seek(rwd)
+    return TorchPickle(fobj).load()

 def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
   """
@@ -260,7 +289,8 @@ def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
     return d * (xl.bitwise_or(xh).bitcast(dtypes.int8) - 32).flatten(-2) * scales
   raise ValueError(f"GGML type '{ggml_type}' is not supported!")

-
+@accept_filename
+def gguf_load(tensor: Tensor) -> tuple[dict, dict[str, Tensor]]:
   """
   Loads a gguf file from a tensor.

@@ -270,31 +300,26 @@ def gguf_load(tensor: Tensor) -> Tuple[Dict, Dict[str, Tensor]]:
   kv_data, state_dict = gguf_load(gguf_tensor)
   ```
   """
-
-
-  def
-    nonlocal pos, read_buffer, rb_start
-    if rb_start + len(read_buffer) < pos + n: rb_start, read_buffer = pos, tensor[pos:(pos+max(n, 1000_000))].data()
-    return read_buffer[pos-rb_start:(pos:=pos+n)-rb_start]
-  def read_unpack(fmt: str, n: int): return struct.unpack(fmt, read_bytes(n))[0]
-  def read_str(): return str(read_bytes(read_uint64()), "utf-8")
+  reader, kv_data, state_dict = io.BufferedReader(TensorIO(tensor), 1_000_000), {}, {}
+  def read_unpack(fmt: str, n: int): return struct.unpack(fmt, reader.read(n))[0]
+  def read_str(): return str(reader.read(read_uint64()), "utf-8")
   def read_arr():
     reader, n = readers[read_int32()], read_uint64()
     return [ reader() for _ in range(n) ]

-  readers:
-    (1,"b",1), (2,"H",2), (3,"h",2), (4,"I",4), (5,"i",4), (6,"f",4), (7,"?",1), (10,"Q",8), (11,"q",8), (12,"d",8) ] } }
+  readers: dict[int, Callable[[], Any]] = { 8: read_str, 9: read_arr, **{ t: functools.partial(read_unpack, "<"+f, nb) for t,f,nb in \
+    [ (0,"c",1), (1,"b",1), (2,"H",2), (3,"h",2), (4,"I",4), (5,"i",4), (6,"f",4), (7,"?",1), (10,"Q",8), (11,"q",8), (12,"d",8) ] } }
   read_uint32, read_int32, read_uint64, read_int64 = readers[4], readers[5], readers[10], readers[11]

-  magic, version, n_tensors, n_kv =
+  magic, version, n_tensors, n_kv = reader.read(4), read_int32(), read_int64(), read_int64()
   if magic != b"GGUF" or version not in [2, 3]: raise ValueError("Invalid GGUF format!")
   for _ in range(n_kv):
     k, typ = read_str(), read_int32()
     kv_data[k] = readers[typ]()

   t_infos = [ (read_str(), tuple(read_uint64() for _ in range(read_uint32())), read_int32(), read_uint64()) for _ in range(n_tensors) ]
-  alignment = kv_data.get("general.alignment", 32)
-  data_start =
+  alignment, pos = kv_data.get("general.alignment", 32), reader.tell()
+  data_start = round_up(pos, alignment)

   for name, dims, typ, off in t_infos: state_dict[name] = ggml_data_to_tensor(tensor[data_start + off:], prod(dims), typ).reshape(*reversed(dims))

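
The `state.py` rework above wraps a 1-D uint8 `Tensor` in the file-like `TensorIO`, and `accept_filename` lets `safe_load`, `tar_extract`, `torch_load`, and `gguf_load` take a `Tensor`, `str`, or `pathlib.Path`. A small save/load round-trip sketch (the model and file name are ours, not from the diff):

```python
import pathlib
from tinygrad import Tensor
from tinygrad.nn.state import safe_save, safe_load, get_state_dict, load_state_dict

class Model:
  def __init__(self): self.w = Tensor.randn(4, 4)

m = Model()
safe_save(get_state_dict(m), "model.safetensors")            # written through a disk: tensor
sd = safe_load("model.safetensors")                          # path as str...
sd2 = safe_load(Tensor(pathlib.Path("model.safetensors")))   # ...or as a uint8 Tensor, via accept_filename
load_state_dict(m, sd)
print(sd["w"].shape)  # (4, 4)
```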