tinygrad 0.10.0__py3-none-any.whl → 0.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. tinygrad/codegen/kernel.py +114 -172
  2. tinygrad/codegen/linearize.py +211 -81
  3. tinygrad/codegen/lowerer.py +30 -35
  4. tinygrad/codegen/{uopgraph.py → rewriter.py} +69 -59
  5. tinygrad/codegen/transcendental.py +12 -13
  6. tinygrad/device.py +170 -47
  7. tinygrad/dtype.py +28 -26
  8. tinygrad/engine/jit.py +80 -63
  9. tinygrad/engine/memory.py +4 -5
  10. tinygrad/engine/multi.py +162 -0
  11. tinygrad/engine/realize.py +58 -107
  12. tinygrad/engine/schedule.py +381 -314
  13. tinygrad/engine/search.py +40 -44
  14. tinygrad/gradient.py +70 -0
  15. tinygrad/helpers.py +77 -58
  16. tinygrad/nn/__init__.py +30 -32
  17. tinygrad/nn/datasets.py +1 -2
  18. tinygrad/nn/optim.py +22 -26
  19. tinygrad/nn/state.py +89 -64
  20. tinygrad/ops.py +562 -446
  21. tinygrad/renderer/__init__.py +79 -36
  22. tinygrad/renderer/cstyle.py +70 -84
  23. tinygrad/renderer/llvmir.py +32 -20
  24. tinygrad/renderer/ptx.py +79 -99
  25. tinygrad/renderer/wgsl.py +87 -0
  26. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  27. tinygrad/runtime/autogen/comgr.py +2 -0
  28. tinygrad/runtime/autogen/kfd.py +4 -3
  29. tinygrad/runtime/autogen/kgsl.py +1 -1
  30. tinygrad/runtime/autogen/libpciaccess.py +2023 -0
  31. tinygrad/runtime/autogen/llvm.py +11379 -0
  32. tinygrad/runtime/autogen/vfio.py +891 -0
  33. tinygrad/runtime/graph/cuda.py +8 -9
  34. tinygrad/runtime/graph/hcq.py +84 -79
  35. tinygrad/runtime/graph/metal.py +19 -21
  36. tinygrad/runtime/ops_amd.py +488 -327
  37. tinygrad/runtime/ops_clang.py +15 -28
  38. tinygrad/runtime/ops_cloud.py +34 -34
  39. tinygrad/runtime/ops_cuda.py +30 -27
  40. tinygrad/runtime/ops_disk.py +62 -63
  41. tinygrad/runtime/ops_dsp.py +129 -38
  42. tinygrad/runtime/ops_gpu.py +30 -30
  43. tinygrad/runtime/ops_hip.py +29 -31
  44. tinygrad/runtime/ops_llvm.py +45 -40
  45. tinygrad/runtime/ops_metal.py +93 -73
  46. tinygrad/runtime/ops_npy.py +2 -2
  47. tinygrad/runtime/ops_nv.py +232 -270
  48. tinygrad/runtime/ops_python.py +51 -46
  49. tinygrad/runtime/ops_qcom.py +129 -157
  50. tinygrad/runtime/ops_webgpu.py +63 -0
  51. tinygrad/runtime/support/allocator.py +94 -0
  52. tinygrad/runtime/support/am/__init__.py +0 -0
  53. tinygrad/runtime/support/am/amdev.py +384 -0
  54. tinygrad/runtime/support/am/ip.py +463 -0
  55. tinygrad/runtime/support/compiler_cuda.py +4 -2
  56. tinygrad/runtime/support/elf.py +26 -4
  57. tinygrad/runtime/support/hcq.py +254 -324
  58. tinygrad/runtime/support/llvm.py +32 -0
  59. tinygrad/shape/shapetracker.py +84 -53
  60. tinygrad/shape/view.py +103 -138
  61. tinygrad/spec.py +154 -0
  62. tinygrad/tensor.py +744 -496
  63. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/METADATA +32 -21
  64. tinygrad-0.10.1.dist-info/RECORD +86 -0
  65. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/WHEEL +1 -1
  66. tinygrad/engine/lazy.py +0 -228
  67. tinygrad/function.py +0 -212
  68. tinygrad/multi.py +0 -177
  69. tinygrad/runtime/graph/clang.py +0 -39
  70. tinygrad-0.10.0.dist-info/RECORD +0 -77
  71. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/LICENSE +0 -0
  72. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/top_level.txt +0 -0
tinygrad/nn/__init__.py CHANGED
@@ -1,7 +1,7 @@
 from __future__ import annotations
 import math
-from typing import Optional, Union, Tuple, List
-from tinygrad.tensor import Tensor, dtypes
+from tinygrad.tensor import Tensor
+from tinygrad.dtype import dtypes
 from tinygrad.device import is_dtype_supported
 from tinygrad.helpers import prod, make_tuple, flatten
 from tinygrad.nn import optim, state, datasets # noqa: F401
@@ -34,14 +34,14 @@ class BatchNorm:
   def __init__(self, sz:int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
     self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum
 
-    self.weight: Optional[Tensor] = Tensor.ones(sz) if affine else None
-    self.bias: Optional[Tensor] = Tensor.zeros(sz) if affine else None
+    self.weight: Tensor|None = Tensor.ones(sz) if affine else None
+    self.bias: Tensor|None = Tensor.zeros(sz) if affine else None
 
     self.num_batches_tracked = Tensor.zeros(1, dtype='long' if is_dtype_supported(dtypes.long) else 'int', requires_grad=False)
     if track_running_stats: self.running_mean, self.running_var = Tensor.zeros(sz, requires_grad=False), Tensor.ones(sz, requires_grad=False)
 
-  def calc_stats(self, x:Tensor) -> Tuple[Tensor, Tensor]:
-    shape_mask: List[int] = [1, -1, *([1]*(x.ndim-2))]
+  def calc_stats(self, x:Tensor) -> tuple[Tensor, Tensor]:
+    shape_mask: list[int] = [1, -1, *([1]*(x.ndim-2))]
     if self.track_running_stats and not Tensor.training: return self.running_mean, self.running_var.reshape(shape=shape_mask).expand(x.shape)
     # This requires two full memory accesses to x
     # https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
@@ -56,12 +56,12 @@ class BatchNorm:
     # NOTE: wow, this is done all throughout training in most PyTorch models
     if self.track_running_stats and Tensor.training:
       self.running_mean.assign((1-self.momentum) * self.running_mean + self.momentum * batch_mean.detach())
-      self.running_var.assign((1-self.momentum) * self.running_var + self.momentum * prod(x.shape)/(prod(x.shape)-x.shape[1]) * batch_var.detach())
+      self.running_var.assign((1-self.momentum) * self.running_var + self.momentum * x.numel()/(x.numel()-x.shape[1]) * batch_var.detach())
       self.num_batches_tracked += 1
     return x.batchnorm(self.weight, self.bias, batch_mean, batch_var.add(self.eps).rsqrt())
 BatchNorm2d = BatchNorm3d = BatchNorm
 
-def Conv1d(in_channels:int, out_channels:int, kernel_size:int, stride=1, padding:Union[int, str]=0, dilation=1, groups=1, bias=True) -> Conv2d:
+def Conv1d(in_channels:int, out_channels:int, kernel_size:int, stride=1, padding:int|str=0, dilation=1, groups=1, bias=True) -> Conv2d:
   """
   Applies a 1D convolution over an input signal composed of several input planes.
 
@@ -95,7 +95,7 @@ class Conv2d:
   print(t.numpy())
   ```
   """
-  def __init__(self, in_channels:int, out_channels:int, kernel_size:Union[int, Tuple[int, ...]], stride=1, padding:Union[int, Tuple[int, ...], str]=0,
+  def __init__(self, in_channels:int, out_channels:int, kernel_size:int|tuple[int, ...], stride=1, padding:int|tuple[int, ...]|str=0,
                dilation=1, groups=1, bias=True):
     self.kernel_size = make_tuple(kernel_size, 2)
     if isinstance(padding, str):
@@ -106,10 +106,9 @@ class Conv2d:
     self.stride, self.dilation, self.groups, self.padding = stride, dilation, groups, padding
     scale = 1 / math.sqrt(in_channels * prod(self.kernel_size))
     self.weight = Tensor.uniform(out_channels, in_channels//groups, *self.kernel_size, low=-scale, high=scale)
-    self.bias: Optional[Tensor] = Tensor.uniform(out_channels, low=-scale, high=scale) if bias else None
+    self.bias: Tensor|None = Tensor.uniform(out_channels, low=-scale, high=scale) if bias else None
 
-  def __call__(self, x:Tensor) -> Tensor:
-    return x.conv2d(self.weight, self.bias, padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
+  def __call__(self, x:Tensor) -> Tensor: return x.conv2d(self.weight, self.bias, self.groups, self.stride, self.dilation, self.padding)
 
 def ConvTranspose1d(in_channels:int, out_channels:int, kernel_size:int, stride=1, padding=0, output_padding=0, dilation=1,
                     groups=1, bias=True) -> ConvTranspose2d:
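The `Conv2d.__call__` one-liner above passes its options positionally, relying on `Tensor.conv2d` taking `(weight, bias, groups, stride, dilation, padding)` in that order. A minimal sketch of the equivalent direct call (layer and input sizes are made up):

```python
from tinygrad import Tensor, nn

conv = nn.Conv2d(3, 8, 3, padding=1)
x = Tensor.rand(1, 3, 16, 16)
# what conv(x) now does under the hood: options passed positionally
y = x.conv2d(conv.weight, conv.bias, conv.groups, conv.stride, conv.dilation, conv.padding)
```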
@@ -146,7 +145,7 @@ class ConvTranspose2d(Conv2d):
   print(t.numpy())
   ```
   """
-  def __init__(self, in_channels:int, out_channels:int, kernel_size:Union[int, Tuple[int, ...]], stride=1, padding=0, output_padding=0,
+  def __init__(self, in_channels:int, out_channels:int, kernel_size:int|tuple[int, ...], stride=1, padding=0, output_padding=0,
                dilation=1, groups=1, bias=True):
     super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
     scale = 1 / math.sqrt(in_channels * prod(self.kernel_size))
@@ -154,8 +153,7 @@ class ConvTranspose2d(Conv2d):
     self.output_padding = output_padding
 
   def __call__(self, x:Tensor) -> Tensor:
-    return x.conv_transpose2d(self.weight, self.bias, padding=self.padding, output_padding=self.output_padding, stride=self.stride,
-                              dilation=self.dilation, groups=self.groups)
+    return x.conv_transpose2d(self.weight, self.bias, self.groups, self.stride, self.dilation, self.padding, self.output_padding)
 
 class Linear:
   """
@@ -178,8 +176,7 @@ class Linear:
     self.weight = Tensor.uniform(out_features, in_features, low=-bound, high=bound)
     self.bias = Tensor.uniform(out_features, low=-bound, high=bound) if bias else None
 
-  def __call__(self, x:Tensor) -> Tensor:
-    return x.linear(self.weight.transpose(), self.bias)
+  def __call__(self, x:Tensor) -> Tensor: return x.linear(self.weight.transpose(), self.bias)
 
 class GroupNorm:
   """
@@ -200,8 +197,8 @@ class GroupNorm:
   """
   def __init__(self, num_groups:int, num_channels:int, eps=1e-5, affine=True):
     self.num_groups, self.num_channels, self.eps = num_groups, num_channels, eps
-    self.weight: Optional[Tensor] = Tensor.ones(num_channels) if affine else None
-    self.bias: Optional[Tensor] = Tensor.zeros(num_channels) if affine else None
+    self.weight: Tensor|None = Tensor.ones(num_channels) if affine else None
+    self.bias: Tensor|None = Tensor.zeros(num_channels) if affine else None
 
   def __call__(self, x:Tensor) -> Tensor:
     # reshape for layernorm to work as group norm
@@ -210,7 +207,7 @@ class GroupNorm:
 
     if self.weight is None or self.bias is None: return x
     # elementwise_affine on channels
-    return x * self.weight.reshape(1, -1, *[1] * (len(x.shape)-2)) + self.bias.reshape(1, -1, *[1] * (len(x.shape)-2))
+    return x * self.weight.reshape(1, -1, *[1] * (x.ndim-2)) + self.bias.reshape(1, -1, *[1] * (x.ndim-2))
 
 class InstanceNorm:
   """
@@ -231,13 +228,13 @@ class InstanceNorm:
   """
   def __init__(self, num_features:int, eps=1e-5, affine=True):
     self.num_features, self.eps = num_features, eps
-    self.weight: Optional[Tensor] = Tensor.ones(num_features) if affine else None
-    self.bias: Optional[Tensor] = Tensor.zeros(num_features) if affine else None
+    self.weight: Tensor|None = Tensor.ones(num_features) if affine else None
+    self.bias: Tensor|None = Tensor.zeros(num_features) if affine else None
 
   def __call__(self, x:Tensor) -> Tensor:
     x = x.reshape(x.shape[0], self.num_features, -1).layernorm(eps=self.eps).reshape(x.shape)
     if self.weight is None or self.bias is None: return x
-    return x * self.weight.reshape(1, -1, *[1] * (len(x.shape)-2)) + self.bias.reshape(1, -1, *[1] * (len(x.shape)-2))
+    return x * self.weight.reshape(1, -1, *[1] * (x.ndim-2)) + self.bias.reshape(1, -1, *[1] * (x.ndim-2))
 
 class LayerNorm:
   """
@@ -256,10 +253,11 @@ class LayerNorm:
   print(t.mean().item(), t.std().item())
   ```
   """
-  def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps=1e-5, elementwise_affine=True):
-    self.normalized_shape: Tuple[int, ...] = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
+  def __init__(self, normalized_shape:int|tuple[int, ...], eps=1e-5, elementwise_affine=True):
+    self.normalized_shape: tuple[int, ...] = make_tuple(normalized_shape, 1)
     self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(self.normalized_shape))), eps, elementwise_affine
-    self.weight, self.bias = (Tensor.ones(*self.normalized_shape), Tensor.zeros(*self.normalized_shape)) if elementwise_affine else (None, None)
+    self.weight: Tensor|None = Tensor.ones(*self.normalized_shape) if elementwise_affine else None
+    self.bias: Tensor|None = Tensor.zeros(*self.normalized_shape) if elementwise_affine else None
 
   def __call__(self, x:Tensor) -> Tensor:
     assert self.normalized_shape == x.shape[-len(self.normalized_shape):], f"last dimensions of {x.shape} must match {self.normalized_shape}"
@@ -322,10 +320,9 @@ class Embedding:
     self.vocab_sz, self.embed_sz, self.weight = vocab_size, embed_size, Tensor.glorot_uniform(vocab_size, embed_size)
 
   def __call__(self, idx:Tensor) -> Tensor:
-    if idx.numel() == 0: return Tensor.empty(idx.shape+(self.embed_sz,), device=self.weight.device)
-    arange_shp, weight_shp, big_shp = (self.vocab_sz, 1), (self.vocab_sz, self.embed_sz), idx.shape+(self.vocab_sz, self.embed_sz,)
-    if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).reshape(arange_shp)
-    arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1,)).expand(big_shp), self.weight.reshape(weight_shp).expand(big_shp)
+    if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).unsqueeze(-1)
+    big_shp = idx.shape+(self.vocab_sz, self.embed_sz)
+    arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1)).expand(big_shp), self.weight.expand(big_shp)
     return (arange == idx).mul(vals).sum(-2, acc_dtype=vals.dtype)
 
 class LSTMCell:
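The rewritten `Embedding.__call__` above is a one-hot gather: broadcast an index grid against the token ids, then contract away the vocab axis. A sketch of the same computation outside the class (all sizes are made up):

```python
from tinygrad import Tensor

vocab_sz, embed_sz = 5, 3
weight = Tensor.glorot_uniform(vocab_sz, embed_sz)
idx = Tensor([[1, 4]])                              # (B, T) token ids
big_shp = idx.shape + (vocab_sz, embed_sz)          # (B, T, V, E)
arange = Tensor.arange(vocab_sz).unsqueeze(-1).expand(big_shp)
onehot = arange == idx.reshape(idx.shape + (1, 1)).expand(big_shp)
out = onehot.mul(weight.expand(big_shp)).sum(-2)    # (B, T, E)
```

Note the new version also drops the old `idx.numel() == 0` early-out.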
@@ -341,9 +338,10 @@ class LSTMCell:
     stdv = 1.0 / math.sqrt(hidden_size)
     self.weight_ih = Tensor.uniform(hidden_size*4, input_size, low=-stdv, high=stdv)
     self.weight_hh = Tensor.uniform(hidden_size*4, hidden_size, low=-stdv, high=stdv)
-    self.bias_ih, self.bias_hh = (Tensor.zeros(hidden_size*4), Tensor.zeros(hidden_size*4)) if bias else (None, None)
+    self.bias_ih: Tensor|None = Tensor.zeros(hidden_size*4) if bias else None
+    self.bias_hh: Tensor|None = Tensor.zeros(hidden_size*4) if bias else None
 
-  def __call__(self, x:Tensor, hc:Optional[Tuple[Tensor, Tensor]]=None) -> Tuple[Tensor, Tensor]:
+  def __call__(self, x:Tensor, hc:tuple[Tensor, Tensor]|None=None) -> tuple[Tensor, Tensor]:
     if hc is None: hc = (Tensor.zeros(x.size(0), self.weight_hh.size(1), dtype=x.dtype, device=x.device),)*2
     gates = x.linear(self.weight_ih.T, self.bias_ih) + hc[0].linear(self.weight_hh.T, self.bias_hh)
     i, f, g, o = gates.chunk(4, dim=1)
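For reference, `LSTMCell` keeps its `(h, c)` tuple interface under the new annotations; a brief usage sketch (sizes are made up, and the `bias=True` constructor default is assumed):

```python
from tinygrad import Tensor, nn

cell = nn.LSTMCell(16, 32)   # (input_size, hidden_size)
x = Tensor.randn(4, 16)      # (batch, input_size)
h, c = cell(x)               # hc=None starts from a zero state
h, c = cell(x, (h, c))       # carry the state forward
```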
tinygrad/nn/datasets.py CHANGED
@@ -1,5 +1,4 @@
 from tinygrad.tensor import Tensor
-from tinygrad.helpers import fetch
 from tinygrad.nn.state import tar_extract
 
 def mnist(device=None, fashion=False):
@@ -9,7 +8,7 @@ def mnist(device=None, fashion=False):
     _mnist("t10k-images-idx3-ubyte.gz")[0x10:].reshape(-1,1,28,28).to(device), _mnist("t10k-labels-idx1-ubyte.gz")[8:].to(device)
 
 def cifar(device=None):
-  tt = tar_extract(fetch('https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz', gunzip=True))
+  tt = tar_extract(Tensor.from_url('https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz', gunzip=True))
   train = Tensor.cat(*[tt[f"cifar-10-batches-bin/data_batch_{i}.bin"].reshape(-1, 3073).to(device) for i in range(1,6)])
   test = tt["cifar-10-batches-bin/test_batch.bin"].reshape(-1, 3073).to(device)
   return train[:, 1:].reshape(-1,3,32,32), train[:, 0], test[:, 1:].reshape(-1,3,32,32), test[:, 0]
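With `fetch` gone here, the download flows through `Tensor.from_url` and stays a tensor end to end; usage is unchanged (a sketch, the first call downloads the archive):

```python
from tinygrad.nn.datasets import cifar

X_train, Y_train, X_test, Y_test = cifar()
print(X_train.shape)  # (50000, 3, 32, 32)
```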
tinygrad/nn/optim.py CHANGED
@@ -1,6 +1,5 @@
 # sorted in order of increasing complexity
-from typing import List
-from tinygrad.helpers import dedup, flatten, getenv
+from tinygrad.helpers import dedup, flatten, getenv, unwrap
 from tinygrad.tensor import Tensor
 from tinygrad.dtype import dtypes, least_upper_dtype
 
@@ -8,15 +7,15 @@ class Optimizer:
   """
   Base class for all optimizers.
   """
-  def __init__(self, params: List[Tensor], lr: float):
+  def __init__(self, params: list[Tensor], lr: float):
     # if it's None, but being put into an optimizer, set it to True
     for x in params:
       if x.requires_grad is None: x.requires_grad = True
 
-    self.params: List[Tensor] = dedup([x for x in params if x.requires_grad])
+    self.params: list[Tensor] = dedup([x for x in params if x.requires_grad])
     assert len(self.params) != 0, "optimizer must have at least one param"
     self.device = self.params[0].device
-    self.buffers: List[Tensor] = dedup([x for x in params if not x.requires_grad]) # buffers are still realized
+    self.buffers: list[Tensor] = dedup([x for x in params if not x.requires_grad]) # buffers are still realized
     # store lr in at least float32 precision
     self.lr = Tensor(lr if getenv("CONST_LR") else [lr], requires_grad=False, device=self.device,
                      dtype=least_upper_dtype(dtypes.default_float, dtypes.float32))
@@ -32,15 +31,17 @@ class Optimizer:
     Performs a single optimization step.
     """
     Tensor.realize(*self.schedule_step())
-  def schedule_step(self) -> List[Tensor]:
+
+  def schedule_step(self) -> list[Tensor]:
     """
     Returns the tensors that need to be realized to perform a single optimization step.
     """
-    assert Tensor.training, (
+    if not Tensor.training: raise RuntimeError(
       f"""Tensor.training={Tensor.training}, Tensor.training must be enabled to use the optimizer.
             - help: Consider setting Tensor.training=True before calling Optimizer.step().""")
-    return self._step()+self.params+self.buffers
-  def _step(self) -> List[Tensor]: raise NotImplementedError
+    return self.schedule_step_with_grads([unwrap(t.grad) for t in self.params])+self.params+self.buffers
+
+  def schedule_step_with_grads(self, grads:list[Tensor]) -> list[Tensor]: raise NotImplementedError
 
 class OptimizerGroup(Optimizer):
   """
@@ -51,10 +52,10 @@ class OptimizerGroup(Optimizer):
     self.params, self.buffers = flatten([o.params for o in self.optimizers]), flatten([o.buffers for o in self.optimizers])
   def __getitem__(self, i): return self.optimizers[i]
   def zero_grad(self): [o.zero_grad() for o in self.optimizers]
-  def _step(self) -> List[Tensor]: return [x for o in self.optimizers for x in o._step()]
+  def schedule_step(self) -> list[Tensor]: return [x for o in self.optimizers for x in o.schedule_step()]
 
 # LARS is essentially just trust ratio to SGD so if we just set the trust coeff 0.0 its just standard SGD.
-def SGD(params: List[Tensor], lr=0.001, momentum=0.0, weight_decay=0.0, nesterov=False, classic=False):
+def SGD(params: list[Tensor], lr=0.001, momentum=0.0, weight_decay=0.0, nesterov=False, classic=False):
   """
   Stochastic Gradient Descent (SGD) optimizer with optional momentum and weight decay.
 
@@ -71,17 +72,13 @@ class LARS(Optimizer):
   - Described: https://paperswithcode.com/method/lars
   - Paper: https://arxiv.org/abs/1708.03888v3
   """
-  def __init__(self, params:List[Tensor], lr=0.001, momentum=0.9, weight_decay=1e-4, nesterov=False, classic=True, tcoef=0.001):
+  def __init__(self, params:list[Tensor], lr=0.001, momentum=0.9, weight_decay=1e-4, nesterov=False, classic=True, tcoef=0.001):
     super().__init__(params, lr)
     self.momentum, self.wd, self.nesterov, self.classic, self.tcoef = momentum, weight_decay, nesterov, classic, tcoef
     self.b = [Tensor.zeros(*t.shape, dtype=t.dtype, device=t.device, requires_grad=False) for t in self.params] if self.momentum else []
 
-  def _step(self) -> List[Tensor]:
-    for i, t in enumerate(self.params):
-      assert t.grad is not None
-      # contiguous is needed since the grads can allegedly form a "diamond"
-      # TODO: fix this in lazy.py
-      g = t.grad.contiguous()
+  def schedule_step_with_grads(self, grads:list[Tensor]) -> list[Tensor]:
+    for i, (t, g) in enumerate(zip(self.params, grads)):
       if self.tcoef != 0:
         r1 = t.detach().square().sum().sqrt()
         r2 = g.square().sum().sqrt()
@@ -99,7 +96,7 @@ class LARS(Optimizer):
     return self.b
 
 # LAMB is essentially just the trust ratio part of LARS applied to Adam/W so if we just set the trust ratio to 1.0 its just Adam/W.
-def AdamW(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.01):
+def AdamW(params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.01):
   """
   AdamW optimizer with optional weight decay.
 
@@ -107,7 +104,7 @@ def AdamW(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, weight_dec
   - Paper: https://arxiv.org/abs/1711.05101v3
   """
   return LAMB(params, lr, b1, b2, eps, weight_decay, adam=True)
-def Adam(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
+def Adam(params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
   """
   Adam optimizer.
 
@@ -123,20 +120,19 @@ class LAMB(Optimizer):
   - Described: https://paperswithcode.com/method/lamb
   - Paper: https://arxiv.org/abs/1904.00962
   """
-  def __init__(self, params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, adam=False):
+  def __init__(self, params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, adam=False):
     super().__init__(params, lr)
     self.b1, self.b2, self.eps, self.wd, self.adam = b1, b2, eps, weight_decay, adam
     self.b1_t, self.b2_t = (Tensor.ones((1,), dtype=dtypes.float32, device=self.device, requires_grad=False).contiguous() for _ in [b1, b2])
     self.m = [Tensor.zeros(*t.shape, dtype=dtypes.float32, device=t.device, requires_grad=False).contiguous() for t in self.params]
     self.v = [Tensor.zeros(*t.shape, dtype=dtypes.float32, device=t.device, requires_grad=False).contiguous() for t in self.params]
 
-  def _step(self) -> List[Tensor]:
+  def schedule_step_with_grads(self, grads:list[Tensor]) -> list[Tensor]:
     self.b1_t *= self.b1
     self.b2_t *= self.b2
-    for i, t in enumerate(self.params):
-      assert t.grad is not None
-      self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * t.grad)
-      self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad))
+    for i, (t, g) in enumerate(zip(self.params, grads)):
+      self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * g)
+      self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (g * g))
       m_hat = self.m[i] / (1.0 - self.b1_t)
      v_hat = self.v[i] / (1.0 - self.b2_t)
      up = (m_hat / (v_hat.sqrt() + self.eps)) + self.wd * t.detach()
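The `_step` hook is gone: subclasses now receive their gradients through `schedule_step_with_grads`, and the base class handles the `unwrap(t.grad)` plumbing (raising `RuntimeError` instead of asserting when `Tensor.training` is off). A minimal sketch of a custom optimizer against the new interface (not part of this diff):

```python
from tinygrad import Tensor
from tinygrad.nn.optim import Optimizer

class PlainSGD(Optimizer):
  def schedule_step_with_grads(self, grads:list[Tensor]) -> list[Tensor]:
    # grads[i] is params[i].grad, already unwrapped by schedule_step
    for t, g in zip(self.params, grads): t.assign((t.detach() - self.lr * g).cast(t.dtype))
    return []  # nothing extra to realize; params and buffers are appended by the base class
```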
tinygrad/nn/state.py CHANGED
@@ -1,24 +1,54 @@
-import os, json, pathlib, zipfile, pickle, tarfile, struct, functools
-from typing import Dict, Union, List, Optional, Any, Tuple, Callable
+import json, pathlib, zipfile, pickle, tarfile, struct, functools, io
+from collections import OrderedDict
+from typing import Union, Optional, Any, Callable, BinaryIO, Iterable
 from tinygrad.tensor import Tensor
 from tinygrad.dtype import dtypes
-from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm
+from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm, round_up, T
 from tinygrad.shape.view import strides_for_shape
-from tinygrad.multi import MultiLazyBuffer
+
+class TensorIO(io.RawIOBase, BinaryIO):
+  def __init__(self, t: Tensor):
+    if t.ndim != 1 or t.dtype != dtypes.uint8: raise ValueError("Tensor must be 1d and of dtype uint8!")
+    self._position, self._tensor = 0, t
+
+  def readable(self) -> bool: return True
+  def read(self, size: int = -1) -> bytes:
+    if (buf:=super().read(size)) is None: raise ValueError("io.RawIOBase.read returned None") # only happens if readinto returns None (never)
+    return buf
+  def readinto(self, buffer: Any) -> int:
+    data = self._tensor[self._position:self._position+len(buffer)].data()
+    buffer[:len(data)] = data
+    self._position += len(data)
+    return len(data)
+
+  def seekable(self) -> bool: return True
+  def seek(self, offset: int, whence: int = 0) -> int:
+    self._position = min(len(self._tensor), max(0, [offset, self._position+offset, len(self._tensor)+offset][whence]))
+    return self._position
+
+  # required to correctly implement BinaryIO
+  def __enter__(self): return self
+  def write(self, s: Any): raise io.UnsupportedOperation("TensorIO.write not supported")
+  def writelines(self, lines: Iterable[Any]): raise io.UnsupportedOperation("TensorIO.writelines not supported")
 
 safe_dtypes = {"BOOL":dtypes.bool, "I8":dtypes.int8, "U8":dtypes.uint8, "I16":dtypes.int16, "U16":dtypes.uint16, "I32":dtypes.int, "U32":dtypes.uint,
                "I64":dtypes.int64, "U64":dtypes.uint64, "F16":dtypes.float16, "BF16":dtypes.bfloat16, "F32":dtypes.float32, "F64":dtypes.float64}
 inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}
 
-def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]:
+def accept_filename(func: Callable[[Tensor], T]) -> Callable[[Union[Tensor, str, pathlib.Path]], T]:
+  @functools.wraps(func)
+  def wrapper(fn: Union[Tensor, str, pathlib.Path]) -> T: return func(Tensor(pathlib.Path(fn)) if not isinstance(fn, Tensor) else fn)
+  return wrapper
+
+@accept_filename
+def safe_load_metadata(t:Tensor) -> tuple[Tensor, int, dict[str, Any]]:
   """
   Loads a .safetensor file from disk, returning the data, metadata length, and metadata.
   """
-  t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
-  json_len = t[0:8].bitcast(dtypes.int64).item()
-  return t, json_len, json.loads(t[8:8+json_len].data().tobytes())
+  data_start = int.from_bytes(t[0:8].data(), "little") + 8
+  return t, data_start, json.loads(t[8:data_start].data().tobytes())
 
-def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
+def safe_load(fn:Union[Tensor, str, pathlib.Path]) -> dict[str, Tensor]:
   """
   Loads a .safetensor file from disk, returning the state_dict.
 
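The new `TensorIO` class above is what lets the loaders below treat a 1-D uint8 tensor as a read-only, seekable binary file. A quick sketch of it in isolation (the file name is made up):

```python
import pathlib, tarfile
from tinygrad import Tensor
from tinygrad.nn.state import TensorIO

t = Tensor(pathlib.Path("archive.tar"))  # disk-backed 1-D uint8 tensor
with tarfile.open(fileobj=TensorIO(t), mode="r") as tar:
  print(tar.getnames())
```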
@@ -26,16 +56,12 @@ def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
   state_dict = nn.state.safe_load("test.safetensor")
   ```
   """
-  t, json_len, metadata = safe_load_metadata(fn)
-  ret = {}
-  for k,v in metadata.items():
-    if k == "__metadata__": continue
-    dtype = safe_dtypes[v['dtype']]
-    sz = (v['data_offsets'][1]-v['data_offsets'][0])
-    ret[k] = t[8+json_len+v['data_offsets'][0]:8+json_len+v['data_offsets'][0]+sz].bitcast(dtype).reshape(v['shape'])
-  return ret
-
-def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any]]=None):
+  t, data_start, metadata = safe_load_metadata(fn)
+  data = t[data_start:]
+  return { k: data[v['data_offsets'][0]:v['data_offsets'][1]].bitcast(safe_dtypes[v['dtype']]).reshape(v['shape'])
+           for k, v in metadata.items() if k != "__metadata__" }
+
+def safe_save(tensors:dict[str, Tensor], fn:str, metadata:Optional[dict[str, Any]]=None):
   """
   Saves a state_dict to disk in a .safetensor file with optional metadata.
 
@@ -50,7 +76,7 @@ def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any
     headers[k] = {'dtype': inverse_safe_dtypes[v.dtype], 'shape': list(v.shape), 'data_offsets':[offset, offset+v.nbytes()]}
     offset += v.nbytes()
   j = json.dumps(headers, separators=(',', ':'))
-  j += "\x20"*((8-len(j)%8)%8)
+  j += "\x20"*(round_up(len(j),8)-len(j))
   pathlib.Path(fn).unlink(missing_ok=True)
   t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}")
   t[0:8].bitcast(dtypes.int64).assign([len(j)])
@@ -59,8 +85,7 @@ def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any
 
 # state dict
 
-from collections import OrderedDict
-def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> Dict[str, Tensor]:
+def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> dict[str, Tensor]:
   """
   Returns a state_dict of the object, with optional prefix.
 
@@ -84,7 +109,8 @@ def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> Dict[str, Tensor]:
   elif isinstance(obj, dict):
     for k,v in obj.items(): state_dict.update(get_state_dict(v, f"{prefix}{str(k)}.", tensor_type))
   return state_dict
-def get_parameters(obj) -> List[Tensor]:
+
+def get_parameters(obj) -> list[Tensor]:
  """
   ```python exec="true" source="above" session="tensor" result="python"
   class Net:
@@ -98,7 +124,7 @@ def get_parameters(obj) -> List[Tensor]:
   """
   return list(get_state_dict(obj).values())
 
-def load_state_dict(model, state_dict:Dict[str, Tensor], strict=True, verbose=True, consume=False) -> None:
+def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=True, consume=False) -> None:
   """
   Loads a state_dict into a model.
 
@@ -114,7 +140,7 @@ def load_state_dict(model, state_dict:Dict[str, Tensor], strict=True, verbose=Tr
   ```
   """
   start_mem_used = GlobalCounters.mem_used
-  with Timing("loaded weights in ", lambda et_ns: f", {(GlobalCounters.mem_used-start_mem_used)/1e9:.2f} GB loaded at {(GlobalCounters.mem_used-start_mem_used)/et_ns:.2f} GB/s"): # noqa: E501
+  with Timing("loaded weights in ", lambda et_ns: f", {(B:=(GlobalCounters.mem_used-start_mem_used))/1e9:.2f} GB loaded at {B/et_ns:.2f} GB/s"):
     model_state_dict = get_state_dict(model)
     if DEBUG >= 1 and len(state_dict) > len(model_state_dict):
       print("WARNING: unused weights in state_dict", sorted(list(state_dict.keys() - model_state_dict.keys())))
@@ -123,27 +149,30 @@ def load_state_dict(model, state_dict:Dict[str, Tensor], strict=True, verbose=Tr
       if k not in state_dict and not strict:
        if DEBUG >= 1: print(f"WARNING: not loading {k}")
        continue
-      if isinstance((mlb:=v.lazydata), MultiLazyBuffer):
-        if isinstance(state_dict[k].lazydata, MultiLazyBuffer): v.replace(state_dict[k]).realize()
-        else: v.replace(state_dict[k].shard(mlb.device, mlb.axis)).realize()
+      if v.shape != state_dict[k].shape:
+        raise ValueError(f'Shape mismatch in layer `{k}`: Expected shape {v.shape}, but found {state_dict[k].shape} in state dict.')
+      if isinstance(v.device, tuple):
+        if isinstance(state_dict[k].device, tuple): v.replace(state_dict[k]).realize()
+        else: v.replace(state_dict[k].shard(v.device, v.lazydata.axis)).realize()
      else: v.replace(state_dict[k].to(v.device)).realize()
      if consume: del state_dict[k]
 
-def tar_extract(fn:os.PathLike) -> Dict[str, Tensor]:
+@accept_filename
+def tar_extract(t: Tensor) -> dict[str, Tensor]:
   """
   Extracts files from a tar archive and returns them as dictionary of names (keys) and tensors (values).
 
   ```python
-  tensors = nn.state.tar_extract("archive.tar")
+  tensors = nn.state.tar_extract(Tensor(pathlib.Path("archive.tar")))
   ```
   """
-  t = Tensor(pathlib.Path(fn))
-  with tarfile.open(fn, "r") as tar:
+  with tarfile.open(fileobj=TensorIO(t), mode="r") as tar:
     return {member.name:t[member.offset_data:member.offset_data+member.size] for member in tar if member.type == tarfile.REGTYPE}
 
 # torch support!
 
-def torch_load(fn:str) -> Dict[str, Tensor]:
+@accept_filename
+def torch_load(t:Tensor) -> dict[str, Tensor]:
   """
   Loads a torch .pth file from disk.
 
@@ -151,10 +180,8 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
   state_dict = nn.state.torch_load("test.pth")
   ```
   """
-  t = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
-
-  offsets: Dict[Union[str, int], int] = {}
-  lens: Dict[Union[str, int], int] = {}
+  offsets: dict[Union[str, int], int] = {}
+  lens: dict[Union[str, int], int] = {}
   def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad=None, backward_hooks=None, metadata=None):
     #print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
     lens[storage[2]] = storage[4] * storage[1].itemsize
@@ -178,7 +205,7 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
   class Parameter:
     def __setstate__(self, state): self.tensor = state[0]
 
-  deserialized_objects: Dict[str, Any] = {}
+  deserialized_objects: dict[str, Any] = {}
   intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16,
                "IntStorage": dtypes.int32, "BoolStorage": dtypes.bool,
                "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2, "FloatTensor": None, "Parameter": Parameter}
@@ -193,8 +220,11 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
       return intercept[name] if module_root == "torch" else super().find_class(module, name)
     def persistent_load(self, pid): return deserialized_objects.get(pid, pid)
 
-  if zipfile.is_zipfile(fn):
-    myzip = zipfile.ZipFile(fn, 'r')
+  fobj = io.BufferedReader(TensorIO(t))
+  def passthrough_reset(v: bool): return fobj.seek(0, 0) or v
+
+  if passthrough_reset(zipfile.is_zipfile(fobj)): # NOTE: passthrough_reset required to support python < 3.14
+    myzip = zipfile.ZipFile(fobj, 'r')
     base_name = myzip.namelist()[0].split('/', 1)[0]
     for n in myzip.namelist():
       if n.startswith(f'{base_name}/data/'):
@@ -202,8 +232,8 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
         offsets[n.split("/")[-1]] = myfile._orig_compress_start # type: ignore
     with myzip.open(f'{base_name}/data.pkl') as myfile:
       return TorchPickle(myfile).load()
-  elif tarfile.is_tarfile(fn):
-    with tarfile.open(fn, "r") as tar:
+  elif passthrough_reset(tarfile.is_tarfile(fobj)): # NOTE: passthrough_reset required to support python < 3.11
+    with tarfile.open(fileobj=fobj, mode="r") as tar:
      storages_offset = tar.getmember('storages').offset_data
      f = unwrap(tar.extractfile('storages'))
      for i in range(TorchPickle(f).load()): # num_storages
@@ -218,14 +248,13 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
      deserialized_objects[str(key)] = _rebuild_tensor_v2((None, storage_type, storage_id, None, -1), storage_offset, size, stride)
    return {k:v.tensor if isinstance(v, Parameter) else v for k,v in TorchPickle(unwrap(tar.extractfile('pickle'))).load().items()}
  else:
-    with open(fn, "rb") as f:
-      pkl = TorchPickle(f)
-      _, _, _, rwd, _, ids, base_offset = pkl.load(), pkl.load(), pkl.load(), f.tell(), pkl.load(), pkl.load(), f.tell()
-      for i in ids:
-        offsets[i] = base_offset + 8
-        base_offset += 8 + lens[i]
-      f.seek(rwd)
-      return TorchPickle(f).load()
+    pkl = TorchPickle(fobj)
+    _, _, _, rwd, _, ids, base_offset = pkl.load(), pkl.load(), pkl.load(), fobj.tell(), pkl.load(), pkl.load(), fobj.tell()
+    for i in ids:
+      offsets[i] = base_offset + 8
+      base_offset += 8 + lens[i]
+    fobj.seek(rwd)
+    return TorchPickle(fobj).load()
 
 def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
   """
@@ -260,7 +289,8 @@ def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
    return d * (xl.bitwise_or(xh).bitcast(dtypes.int8) - 32).flatten(-2) * scales
  raise ValueError(f"GGML type '{ggml_type}' is not supported!")
 
-def gguf_load(tensor: Tensor) -> Tuple[Dict, Dict[str, Tensor]]:
+@accept_filename
+def gguf_load(tensor: Tensor) -> tuple[dict, dict[str, Tensor]]:
   """
   Loads a gguf file from a tensor.
 
@@ -270,31 +300,26 @@ def gguf_load(tensor: Tensor) -> Tuple[Dict, Dict[str, Tensor]]:
   kv_data, state_dict = gguf_load(gguf_tensor)
   ```
   """
-  if tensor.dtype != dtypes.uint8 or len(tensor.shape) != 1: raise ValueError("GGUF tensor must be 1d and of dtype uint8!")
-  pos, read_buffer, rb_start, kv_data, state_dict = 0, memoryview(bytes()), 0, {}, {}
-  def read_bytes(n: int):
-    nonlocal pos, read_buffer, rb_start
-    if rb_start + len(read_buffer) < pos + n: rb_start, read_buffer = pos, tensor[pos:(pos+max(n, 1000_000))].data()
-    return read_buffer[pos-rb_start:(pos:=pos+n)-rb_start]
-  def read_unpack(fmt: str, n: int): return struct.unpack(fmt, read_bytes(n))[0]
-  def read_str(): return str(read_bytes(read_uint64()), "utf-8")
+  reader, kv_data, state_dict = io.BufferedReader(TensorIO(tensor), 1_000_000), {}, {}
+  def read_unpack(fmt: str, n: int): return struct.unpack(fmt, reader.read(n))[0]
+  def read_str(): return str(reader.read(read_uint64()), "utf-8")
   def read_arr():
     reader, n = readers[read_int32()], read_uint64()
     return [ reader() for _ in range(n) ]
 
-  readers: Dict[int, Callable[[], Any]] = { 8: read_str, 9: read_arr, **{ t: functools.partial(read_unpack, "<"+f, nb) for t, f, nb in [ (0,"c",1),
-    (1,"b",1), (2,"H",2), (3,"h",2), (4,"I",4), (5,"i",4), (6,"f",4), (7,"?",1), (10,"Q",8), (11,"q",8), (12,"d",8) ] } }
+  readers: dict[int, Callable[[], Any]] = { 8: read_str, 9: read_arr, **{ t: functools.partial(read_unpack, "<"+f, nb) for t,f,nb in \
+    [ (0,"c",1), (1,"b",1), (2,"H",2), (3,"h",2), (4,"I",4), (5,"i",4), (6,"f",4), (7,"?",1), (10,"Q",8), (11,"q",8), (12,"d",8) ] } }
   read_uint32, read_int32, read_uint64, read_int64 = readers[4], readers[5], readers[10], readers[11]
 
-  magic, version, n_tensors, n_kv = read_bytes(4), read_int32(), read_int64(), read_int64()
+  magic, version, n_tensors, n_kv = reader.read(4), read_int32(), read_int64(), read_int64()
   if magic != b"GGUF" or version not in [2, 3]: raise ValueError("Invalid GGUF format!")
   for _ in range(n_kv):
     k, typ = read_str(), read_int32()
     kv_data[k] = readers[typ]()
 
   t_infos = [ (read_str(), tuple(read_uint64() for _ in range(read_uint32())), read_int32(), read_uint64()) for _ in range(n_tensors) ]
-  alignment = kv_data.get("general.alignment", 32)
-  data_start = pos = pos + (alignment - pos % alignment if pos % alignment != 0 else 0)
+  alignment, pos = kv_data.get("general.alignment", 32), reader.tell()
+  data_start = round_up(pos, alignment)
 
   for name, dims, typ, off in t_infos: state_dict[name] = ggml_data_to_tensor(tensor[data_start + off:], prod(dims), typ).reshape(*reversed(dims))
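Taken together, the safetensors changes read the format as: bytes 0-7 hold the little-endian JSON header length N, the header spans [8, 8+N), and each `data_offsets` pair is relative to `data_start = 8 + N` (which `safe_load_metadata` now returns in place of the raw header length). A sketch of a round trip through the new API (the file name is made up):

```python
from tinygrad import Tensor
from tinygrad.nn.state import safe_save, safe_load, safe_load_metadata

safe_save({"w": Tensor.ones(2, 2)}, "model.safetensors")
t, data_start, metadata = safe_load_metadata("model.safetensors")
print(data_start, metadata["w"]["data_offsets"])  # offsets are relative to data_start
print(safe_load("model.safetensors")["w"].numpy())
```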