tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
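The list above also reflects a broad internal reorganization: `tinygrad/ops.py`, `tinygrad/spec.py`, and `tinygrad/codegen/symbolic.py` move under `tinygrad/uop/`, kernel optimization moves to `tinygrad/codegen/opt/`, and scheduling moves to `tinygrad/schedule/`. As a rough, unofficial sketch of what this means for downstream code that imported these internals (the try/except fallback is illustrative, not from the package):

```python
# Hypothetical compatibility shim for code that imported tinygrad internals.
# The new path matches the 0.11.0 file list above; the fallback covers the 0.10.2 layout.
try:
    from tinygrad.uop.ops import UOp, Ops, PatternMatcher  # 0.11.0: tinygrad/uop/ops.py
except ImportError:
    from tinygrad.ops import UOp, Ops, PatternMatcher      # 0.10.2: tinygrad/ops.py (removed in 0.11.0)
```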
tinygrad/nn/optim.py CHANGED
@@ -1,5 +1,6 @@
  # sorted in order of increasing complexity
- from tinygrad.helpers import dedup, flatten, getenv, unwrap
+ import itertools
+ from tinygrad.helpers import dedup, flatten, getenv, unwrap, FUSE_OPTIM
  from tinygrad.tensor import Tensor
  from tinygrad.dtype import dtypes, least_upper_dtype

@@ -7,7 +8,7 @@ class Optimizer:
  """
  Base class for all optimizers.
  """
- def __init__(self, params: list[Tensor], lr: float):
+ def __init__(self, params: list[Tensor], lr: float, fused=FUSE_OPTIM):
  # if it's None, but being put into an optimizer, set it to True
  for x in params:
  if x.requires_grad is None: x.requires_grad = True
@@ -16,9 +17,16 @@ class Optimizer:
  assert len(self.params) != 0, "optimizer must have at least one param"
  self.device = self.params[0].device
  self.buffers: list[Tensor] = dedup([x for x in params if not x.requires_grad]) # buffers are still realized
+ self.fused = fused
  # store lr in at least float32 precision
  self.lr = Tensor(lr if getenv("CONST_LR") else [lr], requires_grad=False, device=self.device,
  dtype=least_upper_dtype(dtypes.default_float, dtypes.float32))
+ if self.fused: self.pos_params = list(itertools.accumulate(self.params, lambda x,y: x+y.numel(), initial=0))
+
+ def _new_optim_param(self) -> list[Tensor]:
+ param_dtype = getenv("OPTIM_DTYPE", "float32")
+ if self.fused: return [Tensor.zeros(self.pos_params[-1], dtype=param_dtype, device=self.device, requires_grad=False).contiguous()]
+ return [Tensor.zeros(*t.shape, dtype=param_dtype, device=t.device, requires_grad=False).contiguous() for t in self.params]

  def zero_grad(self):
  """
@@ -39,9 +47,17 @@ class Optimizer:
  if not Tensor.training: raise RuntimeError(
  f"""Tensor.training={Tensor.training}, Tensor.training must be enabled to use the optimizer.
  - help: Consider setting Tensor.training=True before calling Optimizer.step().""")
- return self.schedule_step_with_grads([unwrap(t.grad) for t in self.params])+self.params+self.buffers
-
- def schedule_step_with_grads(self, grads:list[Tensor]) -> list[Tensor]: raise NotImplementedError
+ if self.fused:
+ # optimizer fusion just concatenates all the buffers, runs the _step, then splits them back up
+ out, extra = self._step([Tensor.cat(*[t.flatten() for t in self.params], dim=0)],
+ [Tensor.cat(*[unwrap(t.grad).flatten() for t in self.params], dim=0)])
+ updated_params = [out[0][self.pos_params[i]:self.pos_params[i+1]].reshape(tt.shape) for i, tt in enumerate(self.params)]
+ else:
+ updated_params, extra = self._step(self.params, [unwrap(t.grad) for t in self.params])
+ for i, tt in enumerate(self.params): tt.assign(updated_params[i])
+ return extra+self.params+self.buffers
+
+ def _step(self, params:list[Tensor], grads:list[Tensor]) -> tuple[list[Tensor], list[Tensor]]: raise NotImplementedError

  class OptimizerGroup(Optimizer):
  """
@@ -54,93 +70,108 @@ class OptimizerGroup(Optimizer):
  def zero_grad(self): [o.zero_grad() for o in self.optimizers]
  def schedule_step(self) -> list[Tensor]: return [x for o in self.optimizers for x in o.schedule_step()]

- # LARS is essentially just trust ratio to SGD so if we just set the trust coeff 0.0 its just standard SGD.
- def SGD(params: list[Tensor], lr=0.001, momentum=0.0, weight_decay=0.0, nesterov=False, classic=False):
+ # LARS is essentially just trust ratio to SGD so if we just set the trust coeff 0.0 it's just standard SGD.
+ def SGD(params: list[Tensor], lr=0.001, momentum=0.0, weight_decay=0.0, nesterov=False, classic=False, fused=FUSE_OPTIM):
  """
  Stochastic Gradient Descent (SGD) optimizer with optional momentum and weight decay.

  `classic` is a boolean flag that determines whether to use the popular momentum update rule or the classic momentum update rule.
+ """
+ return LARS(params, lr, momentum, weight_decay, 0, None, nesterov, classic=classic, pre_wd=True, tcoef=0.0, fused=fused)

- - Described: https://paperswithcode.com/method/sgd
+ # Muon applies the newton schulz algorithm on gradient. also can include momentum, nesterov, and weight decay
+ def Muon(params: list[Tensor], lr=0.02, momentum=0.95, weight_decay=0.0, ns_steps=5, ns_params=(3.4445, -4.775, 2.0315),
+ nesterov=True, fused=FUSE_OPTIM):
  """
- return LARS(params, lr, momentum, weight_decay, nesterov, classic, tcoef=0.0)
+ SGD with newton-schulz iteration and post momentum weight decay.
+
+ - Described: https://kellerjordan.github.io/posts/muon/
+ - Paper: https://arxiv.org/pdf/2502.16982
+ """
+ assert not fused, "FUSE_OPTIM not allowed for Muon optimizer"
+ return LARS(params, lr, momentum, weight_decay, ns_steps, ns_params, nesterov, classic=False, pre_wd=False, tcoef=0.0, fused=fused)

  class LARS(Optimizer):
  """
  Layer-wise Adaptive Rate Scaling (LARS) optimizer with optional momentum and weight decay.

- - Described: https://paperswithcode.com/method/lars
  - Paper: https://arxiv.org/abs/1708.03888v3
  """
- def __init__(self, params:list[Tensor], lr=0.001, momentum=0.9, weight_decay=1e-4, nesterov=False, classic=True, tcoef=0.001):
- super().__init__(params, lr)
- self.momentum, self.wd, self.nesterov, self.classic, self.tcoef = momentum, weight_decay, nesterov, classic, tcoef
- self.b = [Tensor.zeros(*t.shape, dtype=t.dtype, device=t.device, requires_grad=False) for t in self.params] if self.momentum else []
-
- def schedule_step_with_grads(self, grads:list[Tensor]) -> list[Tensor]:
- for i, (t, g) in enumerate(zip(self.params, grads)):
+ def __init__(self, params:list[Tensor], lr=0.001, momentum=0.9, weight_decay=1e-4, ns_steps=0, ns_params=None,
+ nesterov=False, classic=True, pre_wd=True, tcoef=0.001, fused=FUSE_OPTIM):
+ super().__init__(params, lr, fused)
+ self.momentum, self.wd, self.ns_steps, self.ns_params = momentum, weight_decay, ns_steps, ns_params
+ self.nesterov, self.classic, self.pre_wd, self.tcoef = nesterov, classic, pre_wd, tcoef
+ self.b = self._new_optim_param() if self.momentum else []
+
+ def _step(self, params:list[Tensor], grads:list[Tensor]) -> tuple[list[Tensor], list[Tensor]]:
+ ret = []
+ for i, (t, g) in enumerate(zip(params, grads)):
  if self.tcoef != 0:
  r1 = t.detach().square().sum().sqrt()
  r2 = g.square().sum().sqrt()
- r = (r1 > 0).where((r2 > 0).where(self.tcoef * r1 / (r2 + self.wd * r1), 1.0), 1.0)
+ r:Tensor|float = (r1 > 0).where((r2 > 0).where(self.tcoef * r1 / (r2 + self.wd * r1), 1.0), 1.0)
  else: r = 1.0
- g = g + self.wd * t.detach()
+ if self.pre_wd and self.wd > 0: g = g + self.wd * t.detach()
  # classic momentum does post learning rate update
  if self.classic: g = g * r * self.lr
  if self.momentum:
- self.b[i].assign(self.momentum * self.b[i] + g) # NOTE: self.b[i] is zero on the first run, no if required
+ # TODO: this contiguous is required for correctness because self.b[i] becomes a non contiguous view
+ # the scheduler should detect this and just insert contiguous
+ self.b[i].assign(self.momentum * self.b[i].contiguous() + g) # NOTE: self.b[i] is zero on the first run, no if required
  g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
+ if self.ns_params: g = g.reshape(g.shape[0], -1).newton_schulz(self.ns_steps, self.ns_params).reshape(g.shape)
+ # muon does post momentum weight decay
+ if not self.pre_wd and self.wd > 0: t = t.detach() * (1.0 - self.wd * self.lr)
  # popular momentum does pre learning rate update
  if not self.classic: g = g * r * self.lr
- t.assign((t.detach() - g).cast(t.dtype))
- return self.b
+ ret.append((t.detach() - g).cast(t.dtype))
+ return ret, self.b

- # LAMB is essentially just the trust ratio part of LARS applied to Adam/W so if we just set the trust ratio to 1.0 its just Adam/W.
- def AdamW(params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.01):
+ # LAMB is essentially just the trust ratio part of LARS applied to Adam/W so if we just set the trust ratio to 1.0 it's just Adam/W.
+ def AdamW(params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.01, fused=FUSE_OPTIM):
  """
  AdamW optimizer with optional weight decay.

- - Described: https://paperswithcode.com/method/adamw
  - Paper: https://arxiv.org/abs/1711.05101v3
  """
- return LAMB(params, lr, b1, b2, eps, weight_decay, adam=True)
- def Adam(params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
+ return LAMB(params, lr, b1, b2, eps, weight_decay, adam=True, fused=fused)
+ def Adam(params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, fused=FUSE_OPTIM):
  """
  Adam optimizer.

- - Described: https://paperswithcode.com/method/adam
  - Paper: https://arxiv.org/abs/1412.6980
  """
- return LAMB(params, lr, b1, b2, eps, 0.0, adam=True)
+ return LAMB(params, lr, b1, b2, eps, 0.0, adam=True, fused=fused)

  class LAMB(Optimizer):
  """
  LAMB optimizer with optional weight decay.

- - Described: https://paperswithcode.com/method/lamb
  - Paper: https://arxiv.org/abs/1904.00962
  """
- def __init__(self, params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, adam=False):
- super().__init__(params, lr)
+ def __init__(self, params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, adam=False, fused=FUSE_OPTIM):
+ super().__init__(params, lr, fused)
  self.b1, self.b2, self.eps, self.wd, self.adam = b1, b2, eps, weight_decay, adam
  self.b1_t, self.b2_t = (Tensor.ones((1,), dtype=dtypes.float32, device=self.device, requires_grad=False).contiguous() for _ in [b1, b2])
- self.m = [Tensor.zeros(*t.shape, dtype=dtypes.float32, device=t.device, requires_grad=False).contiguous() for t in self.params]
- self.v = [Tensor.zeros(*t.shape, dtype=dtypes.float32, device=t.device, requires_grad=False).contiguous() for t in self.params]
+ self.m = self._new_optim_param()
+ self.v = self._new_optim_param()

- def schedule_step_with_grads(self, grads:list[Tensor]) -> list[Tensor]:
+ def _step(self, params:list[Tensor], grads:list[Tensor]) -> tuple[list[Tensor], list[Tensor]]:
+ ret = []
  self.b1_t *= self.b1
  self.b2_t *= self.b2
- for i, (t, g) in enumerate(zip(self.params, grads)):
- self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * g)
- self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (g * g))
+ for i, (t, g) in enumerate(zip(params, grads)):
+ self.m[i].assign((self.b1 * self.m[i] + (1.0 - self.b1) * g).cast(self.m[i].dtype))
+ self.v[i].assign((self.b2 * self.v[i] + (1.0 - self.b2) * (g * g)).cast(self.v[i].dtype))
  m_hat = self.m[i] / (1.0 - self.b1_t)
  v_hat = self.v[i] / (1.0 - self.b2_t)
  up = (m_hat / (v_hat.sqrt() + self.eps)) + self.wd * t.detach()
  if not self.adam:
  r1 = t.detach().square().sum().sqrt()
  r2 = up.square().sum().sqrt()
- r = Tensor.where(r1 > 0, Tensor.where(r2 > 0, r1 / r2, 1.0), 1.0)
+ r: Tensor|float = Tensor.where(r1 > 0, Tensor.where(r2 > 0, r1 / r2, 1.0), 1.0)
  else:
  r = 1.0
- t.assign((t.detach() - self.lr * r * up).cast(t.dtype))
- return [self.b1_t, self.b2_t] + self.m + self.v
+ ret.append((t.detach() - self.lr * r * up).cast(t.dtype))
+ return ret, [self.b1_t, self.b2_t] + self.m + self.v
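All optimizer constructors above gain a `fused=FUSE_OPTIM` flag: when enabled, `Optimizer.step()` concatenates every parameter and gradient into a single flat tensor, runs `_step` once over it, then slices and reshapes the result back onto the original parameters using the precomputed `pos_params` offsets. A minimal usage sketch, assuming the 0.11.0 API shown in this diff (the toy tensors and shapes are made up for illustration):

```python
from tinygrad import Tensor, nn

# toy parameters; any list of requires_grad tensors works
W = Tensor.randn(64, 64, requires_grad=True)
b = Tensor.zeros(64, requires_grad=True)

opt = nn.optim.AdamW([W, b], lr=1e-3, fused=True)  # or set FUSE_OPTIM=1 for the env-backed default

Tensor.training = True  # the optimizer asserts this, per the RuntimeError above
x, y = Tensor.randn(8, 64), Tensor.randn(8, 64)
loss = ((x @ W + b) - y).square().mean()
opt.zero_grad()
loss.backward()
opt.step()  # one fused _step over the concatenated params instead of one update per tensor
```

The new `Muon` helper reuses the same LARS machinery with newton-schulz post-processing of the gradient and post-momentum weight decay, and it explicitly rejects `fused=True`.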
tinygrad/nn/state.py CHANGED
@@ -1,6 +1,6 @@
  import json, pathlib, zipfile, pickle, tarfile, struct, functools, io
  from collections import OrderedDict
- from typing import Union, Optional, Any, Callable, BinaryIO, Iterable
+ from typing import Any, Callable, BinaryIO, Iterable
  from tinygrad.tensor import Tensor
  from tinygrad.dtype import dtypes
  from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm, round_up, T
@@ -35,22 +35,22 @@ safe_dtypes = {"BOOL":dtypes.bool, "I8":dtypes.int8, "U8":dtypes.uint8, "I16":dt
  "I64":dtypes.int64, "U64":dtypes.uint64, "F16":dtypes.float16, "BF16":dtypes.bfloat16, "F32":dtypes.float32, "F64":dtypes.float64}
  inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}

- def accept_filename(func: Callable[[Tensor], T]) -> Callable[[Union[Tensor, str, pathlib.Path]], T]:
+ def accept_filename(func: Callable[[Tensor], T]) -> Callable[[Tensor|str|pathlib.Path], T]:
  @functools.wraps(func)
- def wrapper(fn: Union[Tensor, str, pathlib.Path]) -> T: return func(Tensor(pathlib.Path(fn)) if not isinstance(fn, Tensor) else fn)
+ def wrapper(fn: Tensor|str|pathlib.Path) -> T: return func(Tensor(pathlib.Path(fn)) if not isinstance(fn, Tensor) else fn)
  return wrapper

  @accept_filename
  def safe_load_metadata(t:Tensor) -> tuple[Tensor, int, dict[str, Any]]:
  """
- Loads a .safetensor file from disk, returning the data, metadata length, and metadata.
+ Loads a .safetensor file, returning the source tensor, data start position, and metadata.
  """
  data_start = int.from_bytes(t[0:8].data(), "little") + 8
  return t, data_start, json.loads(t[8:data_start].data().tobytes())

- def safe_load(fn:Union[Tensor, str, pathlib.Path]) -> dict[str, Tensor]:
+ def safe_load(fn:Tensor|str|pathlib.Path) -> dict[str, Tensor]:
  """
- Loads a .safetensor file from disk, returning the state_dict.
+ Loads a .safetensor file, returning the `state_dict`.

  ```python
  state_dict = nn.state.safe_load("test.safetensor")
@@ -61,9 +61,9 @@ def safe_load(fn:Union[Tensor, str, pathlib.Path]) -> dict[str, Tensor]:
  return { k: data[v['data_offsets'][0]:v['data_offsets'][1]].bitcast(safe_dtypes[v['dtype']]).reshape(v['shape'])
  for k, v in metadata.items() if k != "__metadata__" }

- def safe_save(tensors:dict[str, Tensor], fn:str, metadata:Optional[dict[str, Any]]=None):
+ def safe_save(tensors:dict[str, Tensor], fn:str, metadata:dict[str, Any]|None=None):
  """
- Saves a state_dict to disk in a .safetensor file with optional metadata.
+ Saves a `state_dict` to disk in a .safetensor file with optional metadata.

  ```python
  t = Tensor([1, 2, 3])
@@ -87,7 +87,7 @@ def safe_save(tensors:dict[str, Tensor], fn:str, metadata:Optional[dict[str, Any

  def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> dict[str, Tensor]:
  """
- Returns a state_dict of the object, with optional prefix.
+ Returns a `state_dict` of the object, with optional prefix.

  ```python exec="true" source="above" session="tensor" result="python"
  class Net:
@@ -124,9 +124,9 @@ def get_parameters(obj) -> list[Tensor]:
  """
  return list(get_state_dict(obj).values())

- def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=True, consume=False) -> None:
+ def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=True, consume=False, realize=True) -> list[Tensor]:
  """
- Loads a state_dict into a model.
+ Loads a `state_dict` into a model. Return the loaded Tensors.

  ```python
  class Net:
@@ -140,7 +140,9 @@ def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=Tr
  ```
  """
  start_mem_used = GlobalCounters.mem_used
- with Timing("loaded weights in ", lambda et_ns: f", {(B:=(GlobalCounters.mem_used-start_mem_used))/1e9:.2f} GB loaded at {B/et_ns:.2f} GB/s"):
+ ret = []
+ with Timing("loaded weights in ",
+ lambda et_ns: f", {(B:=(GlobalCounters.mem_used-start_mem_used))/1e9:.2f} GB loaded at {B/et_ns:.2f} GB/s", enabled=verbose):
  model_state_dict = get_state_dict(model)
  if DEBUG >= 1 and len(state_dict) > len(model_state_dict):
  print("WARNING: unused weights in state_dict", sorted(list(state_dict.keys() - model_state_dict.keys())))
@@ -152,15 +154,22 @@ def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=Tr
  if v.shape != state_dict[k].shape:
  raise ValueError(f'Shape mismatch in layer `{k}`: Expected shape {v.shape}, but found {state_dict[k].shape} in state dict.')
  if isinstance(v.device, tuple):
- if isinstance(state_dict[k].device, tuple): v.replace(state_dict[k]).realize()
- else: v.replace(state_dict[k].shard(v.device, v.lazydata.axis)).realize()
- else: v.replace(state_dict[k].to(v.device)).realize()
+ if isinstance(state_dict[k].device, tuple): v.replace(state_dict[k])
+ else: v.replace(state_dict[k].shard(v.device, v.uop.axis))
+ else: v.replace(state_dict[k].to(v.device))
+ if realize: v.realize()
  if consume: del state_dict[k]
+ ret.append(v)
+ return ret

  @accept_filename
  def tar_extract(t: Tensor) -> dict[str, Tensor]:
  """
- Extracts files from a tar archive and returns them as dictionary of names (keys) and tensors (values).
+ ```python
+ tar_extract(fn: Tensor | str | Path) -> dict[str, Tensor]
+ ```
+
+ Extracts files from a tar archive and returns them as a dictionary of names (keys) and tensors (values).

  ```python
  tensors = nn.state.tar_extract(Tensor(pathlib.Path("archive.tar")))
@@ -174,14 +183,18 @@ def tar_extract(t: Tensor) -> dict[str, Tensor]:
  @accept_filename
  def torch_load(t:Tensor) -> dict[str, Tensor]:
  """
- Loads a torch .pth file from disk.
+ ```python
+ torch_load(fn: Tensor | str | Path) -> dict[str, Tensor]
+ ```
+
+ Loads a torch .pth file, returning the `state_dict`.

  ```python
  state_dict = nn.state.torch_load("test.pth")
  ```
  """
- offsets: dict[Union[str, int], int] = {}
- lens: dict[Union[str, int], int] = {}
+ offsets: dict[str|int, int] = {}
+ lens: dict[str|int, int] = {}
  def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad=None, backward_hooks=None, metadata=None):
  #print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
  lens[storage[2]] = storage[4] * storage[1].itemsize
@@ -292,13 +305,14 @@ def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
  @accept_filename
  def gguf_load(tensor: Tensor) -> tuple[dict, dict[str, Tensor]]:
  """
- Loads a gguf file from a tensor.
+ Loads a .gguf file, returning the `kv_data` and `state_dict`.

  ```python
- fn = "Meta-Llama-3-8B-Instruct.Q4_0.gguf"
- gguf_tensor = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}").to(Device.DEFAULT)
- kv_data, state_dict = gguf_load(gguf_tensor)
+ gguf_tensor = Tensor(pathlib.Path("Meta-Llama-3-8B-Instruct.Q4_0.gguf")).to(Device.DEFAULT)
+ kv_data, state_dict = nn.state.gguf_load(gguf_tensor)
  ```
+
+ NOTE: The provided tensor must be on a device that supports execution.
  """
  reader, kv_data, state_dict = io.BufferedReader(TensorIO(tensor), 1_000_000), {}, {}
  def read_unpack(fmt: str, n: int): return struct.unpack(fmt, reader.read(n))[0]
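Two behavior changes in this file are worth calling out: `load_state_dict` now returns the list of loaded Tensors and takes a `realize=True` flag so realization can be deferred and batched by the caller, and the loaders accept a path or a `Tensor` directly via `accept_filename`. A small unofficial sketch against the new signatures (the `Net` class below is invented for illustration):

```python
from tinygrad import nn
from tinygrad.nn.state import get_state_dict, load_state_dict, safe_save, safe_load

class Net:  # invented toy model
    def __init__(self): self.l1 = nn.Linear(4, 2)

safe_save(get_state_dict(Net()), "net.safetensors")

# load_state_dict now returns the loaded tensors; realize=False defers realization
net2 = Net()
loaded = load_state_dict(net2, safe_load("net.safetensors"), realize=False)
for t in loaded: t.realize()  # realize later, e.g. after all replace()/shard() calls are queued
```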
tinygrad/renderer/__init__.py CHANGED
@@ -1,46 +1,13 @@
  from __future__ import annotations
- from typing import Optional, Callable
- import functools, math
- from enum import Enum, auto
+ from typing import Callable, cast, TYPE_CHECKING
+ import functools, itertools
  from dataclasses import dataclass, field, replace
  from tinygrad.helpers import to_function_name, dedup, prod
- from tinygrad.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, GroupOp, PatternMatcher
- from tinygrad.dtype import DType
-
- class OptOps(Enum):
- TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
- GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); SWAP = auto() # noqa: E702
- def __lt__(self, x:OptOps): return self.value < x.value
-
- @dataclass(frozen=True, order=True)
- class Opt:
- op: OptOps
- axis: Optional[int] = None
- arg: Optional[int | tuple] = None
- def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, arg={self.arg})"
-
- @dataclass(frozen=True)
- class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N)
- dims: tuple[int,int,int] # N, M, K
- threads: int # number of threads that construct the warp
- elements_per_thread: tuple[int, int, int] # elements per-thread to load/store from A/B/C
- dtype_in: DType # dtype for A and B
- dtype_out: DType # dtype for C and D
- opts: tuple[str, ...] # ordered tuple of "ux" or "lx" specifing kernel opts to perform. "ux" upcasts dim x and "lx" localizes dim x
- swizzle: tuple[Optional[tuple[tuple[int, ...], tuple[int, ...]]], Optional[tuple[tuple[int, ...], tuple[int, ...]]]] = (None, None)
- def get_reduce_axes(self): return [(i, 2) for i in range(int(math.log2(self.dims[2])))]
- def get_upcast_axes(self): return [opt for opt in self.opts if opt[0] == "u"]
- def get_local_axes(self): return [opt for opt in self.opts if opt[0] == "l"]
- def __str__(self): return "_".join(["WMMA"] + list(map(str, self.dims)) + [self.dtype_in.name, self.dtype_out.name])
- def __post_init__(self):
- local_axes, upcast_axes, reduce_axes = len(self.get_local_axes()), len(self.get_upcast_axes()), len(self.get_reduce_axes())
- assert self.dims[0] * self.dims[1] == 2**(local_axes + upcast_axes), (
- f"N({self.dims[0]}) x M({self.dims[1]}) != local({2**local_axes}) x upcast({2**upcast_axes}) with opts({self.opts})")
- assert 2**local_axes == self.threads, f"{self.threads} threads construct the warp but found {2**local_axes} in {self.opts}"
- assert 2**upcast_axes == self.elements_per_thread[2], (
- f"{self.elements_per_thread[2]} elements from C are processed per thread but found {2**upcast_axes} in {self.opts}")
- assert all(len(perm[0]) == local_axes and len(perm[1]) == reduce_axes + upcast_axes for perm in self.swizzle if perm), (
- f"swizzle perm should be of len (({local_axes})({reduce_axes + upcast_axes}))")
+ from tinygrad.uop.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, GroupOp, PatternMatcher
+ from tinygrad.dtype import AddrSpace, PtrDType
+ if TYPE_CHECKING:
+ from tinygrad.codegen.opt.tc import TensorCore
+ from tinygrad.codegen.opt.kernel import Opt

  @dataclass(frozen=True)
  class Estimates:
@@ -61,19 +28,23 @@ class Estimates:
  dont_count: set[UOp] = set()
  if ignore_indexing:
  for u in uops:
- if u.op in {Ops.LOAD, Ops.STORE}:
- dont_count = dont_count.union(u.src[0].toposort)
- if len(u.src) > 2: dont_count = dont_count.union(u.src[2].toposort)
+ if u.op in {Ops.LOAD, Ops.STORE} and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
+ dont_count = dont_count.union(u.src[0].toposort())
+ if len(u.src) > 2: dont_count = dont_count.union(u.src[2].toposort())
  elif u.op is Ops.IF:
- dont_count = dont_count.union(u.src[0].toposort)
+ dont_count = dont_count.union(u.src[0].toposort())
  for u in uops:
  if u.op is Ops.RANGE:
  mult_stack.append(mults)
- mults *= (u.src[1] - u.src[0]).ssimplify()
+ mults *= cast(sint, u.src[0].ssimplify())
+ # SPECIAL are already counted in mults
+ mults = mults.substitute({x:x.const_like(0) for x in mults.toposort() if x.op is Ops.SPECIAL}) if isinstance(mults, UOp) else mults
  elif u.op is Ops.ENDRANGE: mults = mult_stack.pop(-1)
  elif u.op is Ops.SPECIAL: mults *= u.arg[1] # NOTE: we don't push to the mult_stack here, you can't end these
- elif u.op is Ops.LOAD: lds += u.dtype.itemsize * mults
- elif u.op is Ops.STORE: lds += u.src[1].dtype.itemsize * mults
+ elif u.op is Ops.LOAD and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
+ lds += u.dtype.itemsize * mults
+ elif u.op is Ops.STORE and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
+ lds += u.src[1].dtype.itemsize * mults
  elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
  elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
  return Estimates(flops, lds, lds) # TODO: properly track memory, lds is always a high estimate
@@ -84,13 +55,11 @@ class ProgramSpec:
  src:str
  device:str
  ast:UOp # save the base ast (this is method cache key)
- uops:Optional[list[UOp]]=None
- applied_opts:Optional[list[Opt]]=None
- mem_estimate:sint=0 # TODO: get this from the load/store uops once min/max are good
+ uops:list[UOp]|None=None

  # filled in from uops (if we have uops)
- global_size:Optional[list[int]]=None
- local_size:Optional[list[int]]=None
+ global_size:list[int]|None=None
+ local_size:list[int]|None=None
  vars:list[Variable]=field(default_factory=list)
  globals:list[int]=field(default_factory=list)
  outs:list[int]=field(default_factory=list)
@@ -103,19 +72,26 @@ class ProgramSpec:
  for u in self.uops:
  if u.op is Ops.DEFINE_VAR: self.vars.append(u)
  if u.op is Ops.DEFINE_GLOBAL: self.globals.append(u.arg)
- if u.op is Ops.STORE: self.outs.extend([x.arg for x in u.src[0].toposort if x.op is Ops.DEFINE_GLOBAL])
- if u.op is Ops.LOAD: self.ins.extend([x.arg for x in u.src[0].toposort if x.op is Ops.DEFINE_GLOBAL])
+ if u.op is Ops.STORE: self.outs.extend([x.arg for x in u.src[0].toposort() if x.op is Ops.DEFINE_GLOBAL])
+ if u.op is Ops.LOAD: self.ins.extend([x.arg for x in u.src[0].toposort() if x.op is Ops.DEFINE_GLOBAL])
  if u.op is Ops.SPECIAL:
  # NOTE: you have to set local_size and global_size to the base [1,1,1] outside this
  if u.arg[0][0] == 'i': self.local_size = None
  special_size = self.local_size if u.arg[0][0] == 'l' else self.global_size
- assert special_size is not None
- special_size[int(u.arg[0][-1])] = u.arg[1]
+ if special_size is not None: special_size[int(u.arg[0][-1])] = u.arg[1]
  self.vars = sorted(self.vars, key=lambda v: v.arg)
  self.outs = sorted(dedup(self.outs))
  self.ins = sorted(dedup(self.ins))
  self._ran_post_init = True

+ @functools.cached_property
+ def mem_estimate(self) -> sint:
+ # group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
+ # TODO: these max and min don't work on symbolic, and results are very wrong.
+ return sum(max(x.src[0].dtype.nbytes() for x in group)
+ for _, group in itertools.groupby([x for x in self.ast.toposort() if x.op in {Ops.LOAD, Ops.STORE} and x.src[0].base.op is Ops.DEFINE_GLOBAL],
+ key=lambda x: (x.op, x.src[0].base.arg)))
+
  @functools.cached_property
  def estimates(self) -> Estimates:
  return replace(Estimates() if self.uops is None else Estimates.from_uops(self.uops, ignore_indexing=True), mem=self.mem_estimate)
@@ -123,6 +99,10 @@ class ProgramSpec:
  @functools.cached_property
  def function_name(self) -> str: return to_function_name(self.name)

+ @property
+ def applied_opts(self) -> tuple[Opt, ...]|None: return self.uops[-1].arg.applied_opts if \
+ self.uops is not None and self.uops[-1].op is Ops.SINK and self.uops[-1].arg is not None else None
+
  def launch_dims(self, var_vals:dict[Variable, int]):
  global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else None
  local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else None
@@ -136,12 +116,12 @@ class Renderer:
  has_local: bool = True
  has_shared: bool = True
  # NOTE: these two should be in (x,y,z) order to match the max_sizes argument in get_grouped_dims
- global_max: Optional[tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: Ops.SPECIAL int32 indexes right now
- local_max: Optional[tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: Ops.SPECIAL int32 indexes right now
+ global_max: tuple[int, ...]|None = (0x8FFFFFFF,) * (3) # TODO: Ops.SPECIAL int32 indexes right now
+ local_max: tuple[int, ...]|None = (0x8FFFFFFF,) * (3) # TODO: Ops.SPECIAL int32 indexes right now
  shared_max: int = 32768
  tensor_cores: list[TensorCore] = []
- pre_matcher: Optional[PatternMatcher] = None
- extra_matcher: Optional[PatternMatcher] = None
+ pre_matcher: PatternMatcher|None = None
+ extra_matcher: PatternMatcher|None = None
  code_for_op: dict[Ops, Callable] = {}

  def __reduce__(self): return self.__class__, ()
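In this file, `ProgramSpec.mem_estimate` is no longer a stored field; it is derived from the AST by grouping global LOAD/STORE uops per (op, buffer) and summing the largest access of each buffer, while `applied_opts` is now read off the SINK uop's arg. A standalone toy illustration of that grouping idea (plain Python with invented data, not tinygrad code; the real version walks uops and uses `dtype.nbytes()`):

```python
import itertools

# stand-ins for LOAD/STORE uops on DEFINE_GLOBAL buffers: (op, buffer_index, access_nbytes)
accesses = [("LOAD", 0, 4096), ("LOAD", 0, 4096), ("LOAD", 1, 1024), ("STORE", 2, 4096)]

def toy_mem_estimate(accesses):
    key = lambda a: (a[0], a[1])                  # group by (op, buffer)
    return sum(max(nbytes for *_, nbytes in grp)  # count each buffer once, at its widest access
               for _, grp in itertools.groupby(sorted(accesses, key=key), key=key))

print(toy_mem_estimate(accesses))  # 4096 + 1024 + 4096 = 9216
```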