tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. tinygrad/__init__.py +11 -6
  2. tinygrad/codegen/kernel.py +308 -175
  3. tinygrad/codegen/linearize.py +95 -0
  4. tinygrad/codegen/lowerer.py +143 -0
  5. tinygrad/codegen/transcendental.py +257 -0
  6. tinygrad/codegen/uopgraph.py +506 -0
  7. tinygrad/device.py +72 -171
  8. tinygrad/dtype.py +122 -47
  9. tinygrad/engine/jit.py +184 -87
  10. tinygrad/{lazy.py → engine/lazy.py} +74 -66
  11. tinygrad/engine/memory.py +51 -0
  12. tinygrad/engine/realize.py +86 -61
  13. tinygrad/engine/schedule.py +366 -317
  14. tinygrad/engine/search.py +58 -47
  15. tinygrad/function.py +59 -58
  16. tinygrad/helpers.py +120 -102
  17. tinygrad/multi.py +82 -78
  18. tinygrad/nn/__init__.py +116 -67
  19. tinygrad/nn/datasets.py +12 -5
  20. tinygrad/nn/optim.py +1 -1
  21. tinygrad/nn/state.py +91 -6
  22. tinygrad/ops.py +1126 -143
  23. tinygrad/renderer/__init__.py +47 -23
  24. tinygrad/renderer/cstyle.py +338 -265
  25. tinygrad/renderer/llvmir.py +125 -143
  26. tinygrad/renderer/ptx.py +225 -0
  27. tinygrad/runtime/autogen/adreno.py +17904 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
  29. tinygrad/runtime/autogen/cuda.py +6 -162
  30. tinygrad/runtime/autogen/io_uring.py +97 -63
  31. tinygrad/runtime/autogen/kfd.py +60 -47
  32. tinygrad/runtime/autogen/kgsl.py +1386 -0
  33. tinygrad/runtime/autogen/libc.py +5462 -0
  34. tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
  35. tinygrad/runtime/autogen/nvrtc.py +579 -0
  36. tinygrad/runtime/autogen/opencl.py +11 -11
  37. tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
  38. tinygrad/runtime/graph/clang.py +3 -3
  39. tinygrad/runtime/graph/cuda.py +11 -15
  40. tinygrad/runtime/graph/hcq.py +120 -107
  41. tinygrad/runtime/graph/metal.py +71 -43
  42. tinygrad/runtime/ops_amd.py +244 -323
  43. tinygrad/runtime/ops_clang.py +12 -5
  44. tinygrad/runtime/ops_cloud.py +220 -0
  45. tinygrad/runtime/ops_cuda.py +42 -99
  46. tinygrad/runtime/ops_disk.py +25 -26
  47. tinygrad/runtime/ops_dsp.py +181 -0
  48. tinygrad/runtime/ops_gpu.py +29 -16
  49. tinygrad/runtime/ops_hip.py +68 -0
  50. tinygrad/runtime/ops_llvm.py +15 -10
  51. tinygrad/runtime/ops_metal.py +147 -64
  52. tinygrad/runtime/ops_nv.py +356 -397
  53. tinygrad/runtime/ops_python.py +78 -79
  54. tinygrad/runtime/ops_qcom.py +405 -0
  55. tinygrad/runtime/support/__init__.py +0 -0
  56. tinygrad/runtime/support/compiler_cuda.py +77 -0
  57. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
  58. tinygrad/runtime/support/elf.py +38 -0
  59. tinygrad/runtime/support/hcq.py +539 -0
  60. tinygrad/shape/shapetracker.py +40 -50
  61. tinygrad/shape/view.py +102 -63
  62. tinygrad/tensor.py +1109 -365
  63. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
  64. tinygrad-0.10.0.dist-info/RECORD +77 -0
  65. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
  66. tinygrad/codegen/linearizer.py +0 -528
  67. tinygrad/codegen/uops.py +0 -451
  68. tinygrad/engine/graph.py +0 -100
  69. tinygrad/renderer/assembly.py +0 -269
  70. tinygrad/shape/symbolic.py +0 -327
  71. tinygrad-0.9.1.dist-info/RECORD +0 -63
  72. /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
  73. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
  74. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
tinygrad/runtime/ops_python.py
@@ -4,21 +4,21 @@
 # this is the (living) definition of uops
 from typing import Tuple, List, Optional, Any, Dict
 import pickle, base64, itertools, time, struct
-from tinygrad.dtype import DType, dtypes, ImageDType
+from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate
 from tinygrad.helpers import all_same, getenv, flatten
 from tinygrad.device import Compiled, Compiler, Allocator
-from tinygrad.codegen.uops import UOpGraph, UOps
-from tinygrad.ops import BinaryOps, TernaryOps, exec_alu, truncate
+from tinygrad.ops import exec_alu, Ops, UOp, GroupOp
 from tinygrad.renderer import Renderer
-from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, AMDRenderer
+from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, AMDRenderer, IntelRenderer, ClangRenderer

 def _load(m, i):
+  if i is None: return 0.0
   if i < 0 or i >= len(m): raise IndexError(f"load out of bounds, size is {len(m)} and access is {i}")
   return m[i]

 def load(inp, j=0):
-  if len(inp) == 4: return [_load(m, x+j) if gate else default for m,x,gate,default in zip(*inp)]
-  return [_load(m, x+j) for m,x in zip(inp[0], inp[1])]
+  if len(inp) == 3: return [_load(m, x+j if x is not None else None) if gate else default for (m,x),default,gate in zip(*inp)]
+  return [_load(m, x+j if x is not None else None) for m,x in inp[0]]

 def _store(m, i, v):
   if i < 0 or i >= len(m): raise IndexError(f"store out of bounds, size is {len(m)}, access is {i}, value is {v}")
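Note on the new load convention above: addresses now arrive as (memoryview, offset) pairs built by the new Ops.INDEX, a None offset marks an out-of-bounds image access that loads as 0.0, and the gated form carries per-lane defaults and gates. A minimal sketch, runnable against the helpers in this hunk:

import struct

buf = memoryview(bytearray(struct.pack("2f", 1.0, 2.0))).cast("f")  # two floats

# ungated: inp == ([(memoryview, offset), ...],)
assert load([[(buf, 0), (buf, 1)]]) == [1.0, 2.0]

# gated: inp == (addresses, defaults, gates); None offsets load 0.0, closed gates take the default
assert load([[(buf, 1), (buf, None)], [-1.0, -1.0], [True, True]]) == [2.0, 0.0]
assert load([[(buf, 0), (buf, 0)], [-1.0, -1.0], [True, False]]) == [1.0, -1.0]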
@@ -26,7 +26,7 @@ def _store(m, i, v):

 class PythonProgram:
   def __init__(self, name:str, lib:bytes):
-    self.uops: List[Tuple[UOps, Optional[DType], List[int], Any]] = pickle.loads(lib)
+    self.uops: List[Tuple[Ops, Optional[DType], List[int], Any]] = pickle.loads(lib)
   def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
     st = time.perf_counter()
     warp = list(itertools.product(*[range(x) for x in local_size[::-1]]))
@@ -40,58 +40,59 @@ class PythonProgram:
       loop_ends: Dict[int, int] = {}
       while i < len(self.uops):
         uop, dtype, idp, arg = self.uops[i]
-        void_ops = {UOps.STORE, UOps.ENDRANGE, UOps.BARRIER, UOps.IF, UOps.ENDIF}
-        if uop is UOps.DEFINE_ACC: idp = [idp[0]]
+        void_ops = {Ops.STORE, Ops.ENDRANGE, Ops.BARRIER, Ops.IF, Ops.ENDIF}
+        if uop is Ops.DEFINE_ACC: idp = [idp[0]]
         inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops]
         dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
         if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp)
-        if uop is UOps.STORE:
-          if len(inp) == 3: inp.append([True] * len(inp[0]))  # set the gate to True
-          if isinstance(dtp[0], ImageDType):
-            # image store
-            assert dtp[2].count == 4
-            for j,val in enumerate(inp[2]):
-              for m,ox,oy,v,g in zip(inp[0], inp[1][0], inp[1][1], val, inp[3]):
-                assert ox >= 0 and ox < dtp[0].shape[1] and oy >= 0 and oy < dtp[0].shape[0]
-                if g: _store(m, ox*4 + oy*dtp[0].shape[1]*4 + j, v)
-          elif dtp[2].count > 1:
-            for j,val in enumerate(inp[2]):
-              for m,o,v,g in zip(inp[0], inp[1], val, inp[3]):
+        if uop is Ops.STORE:
+          if len(inp) == 2: inp.append([True] * len(inp[0]))  # set the gate to True
+          if dtp[1].count > 1:
+            for j,val in enumerate(inp[1]):
+              for (m,o),v,g in zip(inp[0], val, inp[2]):
                 if g: _store(m, o+j, v)
           else:
-            for m,o,v,g in zip(*inp):
+            for (m,o),v,g in zip(*inp):
              if g: _store(m, o, v)
          i += 1
          continue
-        if uop is UOps.ENDRANGE:
+        if uop is Ops.ENDRANGE:
          loop_ends[idp[0]] = i
          i = idp[0]
          continue
-        if uop in (UOps.BARRIER, UOps.IF, UOps.ENDIF):
+        if uop in (Ops.BARRIER, Ops.IF, Ops.ENDIF):
          # in the python emulator, the warp is always in sync
          i += 1
          continue
         assert dtype is not None, f"{uop} is missing a dtype"
         dl[i] = dtype
-        if uop is UOps.DEFINE_GLOBAL:
+        if uop is Ops.DEFINE_GLOBAL:
          assert dtype.fmt is not None
          ul[i] = [pbufs.pop(0).cast(dtype.fmt)] * warp_size
-        elif uop is UOps.DEFINE_LOCAL:
+        elif uop is Ops.DEFINE_LOCAL:
          assert dtype.fmt is not None
          lbuf = memoryview(bytearray(arg[1]*dtype.itemsize))
          ul[i] = [lbuf.cast(dtype.fmt)] * warp_size
-        elif uop is UOps.DEFINE_VAR:
+        elif uop is Ops.DEFINE_VAR:
          ul[i] = [pvals.pop(0)] * warp_size
-        elif uop is UOps.SPECIAL:
-          if arg[1][0] == 'g':
-            ul[i] = [idxs[2-arg[0]]] * warp_size
-          elif arg[1][0] == 'l':
-            ul[i] = [x[2-arg[0]] for x in warp]
-        elif uop is UOps.CONST:
-          ul[i] = [[arg] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg] * warp_size
-        elif uop is UOps.DEFINE_ACC:
-          ul[i] = [[inp[0][0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [inp[0][0]] * warp_size
-        elif uop is UOps.RANGE:
+        elif uop is Ops.SPECIAL:
+          if arg[0][0] == 'g': ul[i] = [idxs[2-int(arg[0][-1])]] * warp_size
+          elif arg[0][0] == 'l': ul[i] = [x[2-int(arg[0][-1])] for x in warp]
+        elif uop is Ops.CONST: ul[i] = [arg] * warp_size
+        elif uop is Ops.DEFINE_ACC:
+          ul[i] = [[inp[0][0][0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [inp[0][0]] * warp_size
+        elif uop is Ops.INDEX:
+          ret = []
+          if isinstance(dtp[0], ImageDType):
+            for m,ox,oy in zip(inp[0], inp[1][0], inp[1][1]):
+              if ox < 0 or ox >= dtp[0].shape[1] or oy < 0 or oy >= dtp[0].shape[0]: ret.append((m, None))
+              else: ret.append((m, ox*4 + oy*dtp[0].shape[1]*4))
+          else:
+            for m,o in zip(inp[0], inp[1]): ret.append((m,o))
+          ul[i] = ret
+        elif uop is Ops.CAST and isinstance(dtype, PtrDType):
+          ul[i] = inp[0]
+        elif uop is Ops.RANGE:
          if i not in ul: ul[i] = [inp[0][0]] * warp_size
          else:
            for j in range(len(ul[i])):
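Note on Ops.SPECIAL above: the arg no longer carries a separate axis index; the dimension name comes first and the axis is recovered from the name's trailing digit. A small sketch of the decoding (the concrete arg value, including the size field, is illustrative):

arg = ("gidx0", 4)               # new-style SPECIAL arg: name first; the 4 is an assumed size field
idxs = (0, 1, 2)                 # (z, y, x): global_size is indexed reversed
axis = int(arg[0][-1])           # trailing digit names the axis -> 0
assert arg[0][0] == 'g'          # 'g' -> global index, 'l' -> local index
assert idxs[2 - axis] == 2       # gidx0 selects the x component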
@@ -100,45 +100,29 @@ class PythonProgram:
               del ul[i]
               i = loop_ends[i] + 1
               continue
-        elif uop in (UOps.CAST, UOps.BITCAST):
-          if dtype.count > 1: ul[i] = inp
-          else:
-            assert dtp[0].fmt and dtype.fmt
-            pack_format, unpack_format = str(warp_size) + dtp[0].fmt, str(warp_size) + dtype.fmt
-            if uop is UOps.BITCAST: ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
-            else:
-              casted = [dtypes.as_const(x, dtype) for x in inp[0]]
-              if dtypes.is_int(dtype):
-                overflow_adjust = 2**(dtype.itemsize*8 - 1) if not dtypes.is_unsigned(dtype) else 0
-                casted = [((x + overflow_adjust) % 2**(dtype.itemsize*8) - overflow_adjust) for x in casted]
-              elif dtypes.is_float(dtype):
-                casted = [truncate.get(dtype, lambda dt: dt)(x) for x in casted]
-              ul[i] = list(struct.unpack(unpack_format, struct.pack(unpack_format, *casted)))
-        elif uop is UOps.LOAD:
-          if isinstance(dtp[0], ImageDType):
-            assert dtype.count == 4
-            ul[i] = []
-            for j in range(dtype.count):
-              ret = []
-              for m,ox,oy in zip(inp[0], inp[1][0], inp[1][1]):
-                if ox < 0 or ox >= dtp[0].shape[1] or oy < 0 or oy >= dtp[0].shape[0]: ret.append(0)
-                else: ret.append(_load(m, ox*4 + oy*dtp[0].shape[1]*4 + j))
-              ul[i].append(ret)
-          elif dtype.count > 1:
-            ul[i] = [load([inp[i][j] if dtp[i].count > 1 else inp[i] for i in range(len(inp))], j) for j in range(dtype.count)]
+        elif uop is Ops.VECTORIZE: ul[i] = inp
+        elif uop in {Ops.CAST, Ops.BITCAST}:
+          assert dtp[0].fmt and dtype.fmt
+          pack_format, unpack_format = str(warp_size) + dtp[0].fmt, str(warp_size) + dtype.fmt
+          if uop is Ops.BITCAST: ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
+          else: ul[i] = [truncate.get(dtype, lambda dt: dt)(dtypes.as_const(x, dtype)) for x in inp[0]]
+        elif uop is Ops.LOAD:
+          if dtype.count > 1:
+            ul[i] = [load([inp[i][j] if i != 0 and dtp[i].count > 1 else inp[i] for i in range(len(inp))], j) for j in range(dtype.count)]
           else:
            ul[i] = load(inp)
-        elif uop is UOps.PHI:
+        elif uop is Ops.ASSIGN:
          for j in range(len(inp[0])): inp[0][j] = inp[1][j]
          ul[i] = inp[0]
-        elif uop is UOps.GEP:
-          ul[i] = inp[0][arg]
-        elif uop is UOps.WMMA:
+        elif uop is Ops.GEP:
+          assert len(arg) == 1
+          ul[i] = inp[0][arg[0]]
+        elif uop is Ops.WMMA:
          # here are the models for the WMMA instruction on the different hardware
          def wmma_helper(WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_elem, b_elem, c_map):
-            assert len(inp[0]) == NUM_A, f"A must have {NUM_A} elements per thread"
-            assert len(inp[1]) == NUM_B, f"B must have {NUM_B} elements per thread"
-            assert len(inp[2]) == NUM_C, f"C must have {NUM_C} elements per thread"
+            assert len(inp[0]) == NUM_A, f"A must have {NUM_A} elements per thread, it has {len(inp[0])}"
+            assert len(inp[1]) == NUM_B, f"B must have {NUM_B} elements per thread, it has {len(inp[1])}"
+            assert len(inp[2]) == NUM_C, f"C must have {NUM_C} elements per thread, it has {len(inp[2])}"
             assert len(flatten(inp[0])) == NUM_A * warp_size, f"WMMA must have {NUM_A * warp_size} total elements for A in WMMA"
             assert len(flatten(inp[1])) == NUM_B * warp_size, f"WMMA must have {NUM_B * warp_size} total elements for B in WMMA"
             assert len(flatten(inp[2])) == NUM_C * warp_size, f"WMMA must have {NUM_C * warp_size} total elements for C in WMMA"
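Note on the CAST path above: the hand-rolled integer overflow_adjust arithmetic is gone; casting now defers to the truncate table imported from tinygrad.dtype, which maps a DType to a function wrapping a Python number into that dtype's range. A quick sketch, assuming the table covers the integer dtypes (which this emulator path relies on):

from tinygrad.dtype import dtypes, truncate

assert truncate[dtypes.uint8](257) == 1      # wraps modulo 2**8
assert truncate[dtypes.int8](128) == -128    # two's-complement wraparound
assert truncate.get(dtypes.float32, lambda x: x)(1.5) == 1.5  # floats round-trip through a C float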
@@ -152,13 +137,13 @@
             return out

           # TODO: refactor these to a shared TensorCoreLayout in kernel.py
-          if arg[5] == "METAL":
+          if arg[4] == "METAL":
             # A (2 elements on 32 threads): row major
             def a_b_elem(x, i, j, goff): return x[(i%2)][goff+(i//2)%2+(j%4)*2+(i//4)*8+(j//4)*16]
             # (i, j), C, D (2 elements on 32 threads): row major same as A/B
             def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4)
             ul[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
-          elif arg[5] == "AMD":
+          elif arg[4] == "AMD":
             # A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
             def a_elem(x, i, j, goff):
               assert x[i][goff+j] == x[i][goff+j+16], "warp elements not duplicated properly across lanes"
@@ -167,7 +152,7 @@
             def b_elem(x, i, j, goff): return a_elem(x, j, i, goff)  # pylint: disable=arguments-out-of-order
             def c_map(lane, elem): return (lane%16, lane//16+elem*2)  # (i, j), C, D (8 elements on 32 threads): row major
             ul[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
-          elif arg[5] == "CUDA":
+          elif arg[4] == "CUDA":
             # A (8 elements on 32 threads)
             def a_elem(x, i, j, goff): return x[(i%2)+(j//8)*2+(i//8)*4][goff+((i//2)%4)+(j%8)*4]
             # B (4 elements on 32 threads)
@@ -175,11 +160,23 @@
             # (i, j), C, D (4 elements on 32 threads)
             def c_map(lane, elem): return ((elem%2)+(lane%4)*2, (lane//4)+(elem//2)*8)
             ul[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
+          elif arg[4] == "INTEL":
+            # A (16 elements on 8 threads)
+            def a_elem(x, i, j, goff): return x[i%2+j*2][goff+i//2]
+            # B (16 elements on 8 threads)
+            def b_elem(x, i, j, goff): return x[j][goff+i]
+            # C, D (8 elements on 8 threads)
+            def c_map(lane, elem): return (lane, elem)
+            ul[i] = wmma_helper(8, 16, 16, 16, 8, a_elem, b_elem, c_map)
+          elif arg[4] == "CLANG":
+            def elem(x, i, j, _): return x[i+j][0]
+            def c_map(_, elem): return (elem%16, elem//16)
+            ul[i] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map)
           else: raise NotImplementedError(f"unimplemented tensor core {arg}")
-        elif uop is UOps.ALU:
-          assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {arg}"
-          assert all_same([dtype] + dtp) or arg in {BinaryOps.CMPNE, BinaryOps.CMPLT, TernaryOps.WHERE}, f"dtype mismatch on {arg}"
-          ul[i] = [exec_alu(arg, dtype, p) for p in zip(*inp)]
+        elif uop in GroupOp.ALU:
+          assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {uop}"
+          assert all_same([dtype] + dtp) or uop in {Ops.CMPNE, Ops.CMPLT, Ops.WHERE}, f"dtype mismatch on {uop}"
+          ul[i] = [exec_alu(uop, dtype, p) for p in zip(*inp)]
         assert i in ul, (uop, dtype, idp, arg)
         i += 1
       return time.perf_counter() - st
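Note on the ALU hunk above: with BinaryOps/TernaryOps folded into the single Ops enum, the opcode is the uop itself (uop in GroupOp.ALU) rather than an arg payload, and exec_alu takes it directly. A quick illustration against the 0.10.0 API:

from tinygrad.ops import exec_alu, Ops
from tinygrad.dtype import dtypes

assert exec_alu(Ops.ADD, dtypes.int32, (3, 4)) == 7
assert exec_alu(Ops.WHERE, dtypes.float32, (True, 1.0, 2.0)) == 1.0  # WHERE picks y when the condition holds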
@@ -190,9 +187,11 @@ class PythonRenderer(Renderer):
     if getenv("EMULATE_METAL"): self.device, self.tensor_cores = "METAL", MetalRenderer.tensor_cores
     if getenv("EMULATE_AMD"): self.device, self.tensor_cores = "AMD", AMDRenderer.tensor_cores
     if getenv("EMULATE_CUDA"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tensor_cores
+    if getenv("EMULATE_INTEL"): self.device, self.suffix, self.tensor_cores = "INTEL", "INTEL", IntelRenderer.tensor_cores
+    if getenv("EMULATE_AMX"): self.device, self.tensor_cores = "CLANG", ClangRenderer.tensor_cores

-  def render(self, name:str, uops:UOpGraph) -> str:
-    lops = [(u.op, u.dtype, [uops.uops.index(v) for v in u.src], u.arg) for u in uops]
+  def render(self, name:str, uops:List[UOp]) -> str:
+    lops = [(u.op, u.dtype, [uops.index(v) for v in u.src], u.arg) for u in uops]
     return base64.b64encode(pickle.dumps(lops)).decode()

 class PythonCompiler(Compiler):
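Note on render above: it now takes a flat List[UOp] (the UOpGraph wrapper is gone), flattens each uop to an (op, dtype, src-indices, arg) tuple, and ships the pickled, base64-encoded list; PythonProgram unpickles it on the other end. A sketch of the round trip with stand-in tuples (real entries hold Ops and DType values):

import pickle, base64

lops = [("DEFINE_GLOBAL", "float*", [], 0), ("CONST", "float", [], 1.0)]  # stand-ins
src = base64.b64encode(pickle.dumps(lops)).decode()  # what render() returns
assert pickle.loads(base64.b64decode(src)) == lops   # what PythonProgram.__init__ recovers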