tinygrad 0.9.0__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57):
  1. tinygrad/codegen/__init__.py +0 -0
  2. tinygrad/codegen/kernel.py +78 -90
  3. tinygrad/codegen/linearizer.py +237 -169
  4. tinygrad/codegen/uops.py +278 -242
  5. tinygrad/device.py +147 -10
  6. tinygrad/dtype.py +7 -7
  7. tinygrad/engine/graph.py +16 -16
  8. tinygrad/engine/jit.py +39 -36
  9. tinygrad/engine/realize.py +6 -5
  10. tinygrad/engine/schedule.py +15 -7
  11. tinygrad/engine/search.py +6 -3
  12. tinygrad/function.py +17 -23
  13. tinygrad/helpers.py +77 -8
  14. tinygrad/lazy.py +26 -26
  15. tinygrad/multi.py +13 -9
  16. tinygrad/nn/__init__.py +1 -1
  17. tinygrad/nn/datasets.py +2 -1
  18. tinygrad/nn/state.py +3 -4
  19. tinygrad/ops.py +49 -16
  20. tinygrad/renderer/__init__.py +8 -4
  21. tinygrad/renderer/assembly.py +93 -100
  22. tinygrad/renderer/cstyle.py +47 -42
  23. tinygrad/renderer/llvmir.py +30 -30
  24. tinygrad/runtime/__init__.py +0 -0
  25. tinygrad/runtime/autogen/amd_gpu.py +11504 -1
  26. tinygrad/runtime/autogen/comgr.py +36 -10
  27. tinygrad/runtime/autogen/hsa.py +146 -14
  28. tinygrad/runtime/autogen/io_uring.py +1486 -0
  29. tinygrad/runtime/autogen/nv_gpu.py +269 -0
  30. tinygrad/runtime/driver/__init__.py +0 -0
  31. tinygrad/runtime/driver/hip_comgr.py +20 -11
  32. tinygrad/runtime/graph/__init__.py +0 -0
  33. tinygrad/runtime/graph/clang.py +3 -2
  34. tinygrad/runtime/graph/cuda.py +2 -2
  35. tinygrad/runtime/graph/hcq.py +122 -78
  36. tinygrad/runtime/ops_amd.py +302 -316
  37. tinygrad/runtime/ops_cuda.py +3 -3
  38. tinygrad/runtime/ops_disk.py +70 -5
  39. tinygrad/runtime/ops_gpu.py +2 -2
  40. tinygrad/runtime/ops_metal.py +5 -6
  41. tinygrad/runtime/ops_npy.py +1 -1
  42. tinygrad/runtime/ops_nv.py +161 -166
  43. tinygrad/runtime/ops_python.py +20 -16
  44. tinygrad/shape/__init__.py +0 -0
  45. tinygrad/shape/shapetracker.py +5 -2
  46. tinygrad/shape/symbolic.py +1 -3
  47. tinygrad/shape/view.py +34 -19
  48. tinygrad/tensor.py +219 -135
  49. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +14 -6
  50. tinygrad-0.9.1.dist-info/RECORD +63 -0
  51. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
  52. tinygrad/runtime/driver/hsa.py +0 -143
  53. tinygrad/runtime/graph/hsa.py +0 -171
  54. tinygrad/runtime/ops_hsa.py +0 -278
  55. tinygrad-0.9.0.dist-info/RECORD +0 -60
  56. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +0 -0
  57. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
tinygrad/runtime/ops_python.py CHANGED
@@ -1,3 +1,4 @@
1
+ # pylint: disable=cell-var-from-loop
1
2
  # a python uops emulator
2
3
  # works to test the tensor cores, and all the uops in general
3
4
  # this is the (living) definition of uops
@@ -7,9 +8,9 @@ from tinygrad.dtype import DType, dtypes, ImageDType
7
8
  from tinygrad.helpers import all_same, getenv, flatten
8
9
  from tinygrad.device import Compiled, Compiler, Allocator
9
10
  from tinygrad.codegen.uops import UOpGraph, UOps
10
- from tinygrad.ops import BinaryOps, TernaryOps, exec_alu
11
+ from tinygrad.ops import BinaryOps, TernaryOps, exec_alu, truncate
11
12
  from tinygrad.renderer import Renderer
12
- from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, HIPRenderer
13
+ from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, AMDRenderer
13
14
 
14
15
  def _load(m, i):
15
16
  if i < 0 or i >= len(m): raise IndexError(f"load out of bounds, size is {len(m)} and access is {i}")
@@ -17,7 +18,7 @@ def _load(m, i):
17
18
 
18
19
  def load(inp, j=0):
19
20
  if len(inp) == 4: return [_load(m, x+j) if gate else default for m,x,gate,default in zip(*inp)]
20
- else: return [_load(m, x+j) for m,x in zip(inp[0], inp[1])]
21
+ return [_load(m, x+j) for m,x in zip(inp[0], inp[1])]
21
22
 
22
23
  def _store(m, i, v):
23
24
  if i < 0 or i >= len(m): raise IndexError(f"store out of bounds, size is {len(m)}, access is {i}, value is {v}")
@@ -40,7 +41,7 @@ class PythonProgram:
40
41
  while i < len(self.uops):
41
42
  uop, dtype, idp, arg = self.uops[i]
42
43
  void_ops = {UOps.STORE, UOps.ENDRANGE, UOps.BARRIER, UOps.IF, UOps.ENDIF}
43
- if uop is UOps.DEFINE_ACC: idp.clear()
44
+ if uop is UOps.DEFINE_ACC: idp = [idp[0]]
44
45
  inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops]
45
46
  dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
46
47
  if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp)
@@ -62,11 +63,11 @@ class PythonProgram:
62
63
  if g: _store(m, o, v)
63
64
  i += 1
64
65
  continue
65
- elif uop is UOps.ENDRANGE:
66
+ if uop is UOps.ENDRANGE:
66
67
  loop_ends[idp[0]] = i
67
68
  i = idp[0]
68
69
  continue
69
- elif uop in (UOps.BARRIER, UOps.IF, UOps.ENDIF):
70
+ if uop in (UOps.BARRIER, UOps.IF, UOps.ENDIF):
70
71
  # in the python emulator, the warp is always in sync
71
72
  i += 1
72
73
  continue
@@ -89,7 +90,7 @@ class PythonProgram:
89
90
  elif uop is UOps.CONST:
90
91
  ul[i] = [[arg] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg] * warp_size
91
92
  elif uop is UOps.DEFINE_ACC:
92
- ul[i] = [[arg[0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg[0]] * warp_size
93
+ ul[i] = [[inp[0][0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [inp[0][0]] * warp_size
93
94
  elif uop is UOps.RANGE:
94
95
  if i not in ul: ul[i] = [inp[0][0]] * warp_size
95
96
  else:
@@ -99,7 +100,7 @@ class PythonProgram:
99
100
  del ul[i]
100
101
  i = loop_ends[i] + 1
101
102
  continue
102
- elif uop in {UOps.CAST, UOps.BITCAST}:
103
+ elif uop in (UOps.CAST, UOps.BITCAST):
103
104
  if dtype.count > 1: ul[i] = inp
104
105
  else:
105
106
  assert dtp[0].fmt and dtype.fmt
@@ -107,9 +108,12 @@ class PythonProgram:
107
108
  if uop is UOps.BITCAST: ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
108
109
  else:
109
110
  casted = [dtypes.as_const(x, dtype) for x in inp[0]]
110
- overflow_adjust = 2**(dtype.itemsize*8 - 1) if (dtypes.is_int(dtype) and not dtypes.is_unsigned(dtype)) else 0
111
- overflow_fixed = [((x + overflow_adjust) % 2**(dtype.itemsize*8) - overflow_adjust) if dtypes.is_int(dtype) else x for x in casted]
112
- ul[i] = list(struct.unpack(unpack_format, struct.pack(unpack_format, *overflow_fixed)))
111
+ if dtypes.is_int(dtype):
112
+ overflow_adjust = 2**(dtype.itemsize*8 - 1) if not dtypes.is_unsigned(dtype) else 0
113
+ casted = [((x + overflow_adjust) % 2**(dtype.itemsize*8) - overflow_adjust) for x in casted]
114
+ elif dtypes.is_float(dtype):
115
+ casted = [truncate.get(dtype, lambda dt: dt)(x) for x in casted]
116
+ ul[i] = list(struct.unpack(unpack_format, struct.pack(unpack_format, *casted)))
113
117
  elif uop is UOps.LOAD:
114
118
  if isinstance(dtp[0], ImageDType):
115
119
  assert dtype.count == 4
@@ -154,13 +158,13 @@ class PythonProgram:
154
158
  # (i, j), C, D (2 elements on 32 threads): row major same as A/B
155
159
  def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4)
156
160
  ul[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
157
- elif arg[5] == "HSA":
161
+ elif arg[5] == "AMD":
158
162
  # A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
159
163
  def a_elem(x, i, j, goff):
160
164
  assert x[i][goff+j] == x[i][goff+j+16], "warp elements not duplicated properly across lanes"
161
165
  return x[i][goff+j]
162
166
  # B (16 elements on 32 threads): row major, lane 16-32 == lane 0-15
163
- def b_elem(x, i, j, goff): return a_elem(x, j, i, goff)
167
+ def b_elem(x, i, j, goff): return a_elem(x, j, i, goff) # pylint: disable=arguments-out-of-order
164
168
  def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major
165
169
  ul[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
166
170
  elif arg[5] == "CUDA":
@@ -174,7 +178,7 @@ class PythonProgram:
174
178
  else: raise NotImplementedError(f"unimplemented tensor core {arg}")
175
179
  elif uop is UOps.ALU:
176
180
  assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {arg}"
177
- assert all_same([dtype] + dtp) or arg in {BinaryOps.CMPEQ, BinaryOps.CMPLT, TernaryOps.WHERE}, f"dtype mismatch on {arg}"
181
+ assert all_same([dtype] + dtp) or arg in {BinaryOps.CMPNE, BinaryOps.CMPLT, TernaryOps.WHERE}, f"dtype mismatch on {arg}"
178
182
  ul[i] = [exec_alu(arg, dtype, p) for p in zip(*inp)]
179
183
  assert i in ul, (uop, dtype, idp, arg)
180
184
  i += 1
@@ -184,11 +188,11 @@ class PythonRenderer(Renderer):
184
188
  device = "PYTHON"
185
189
  def __init__(self):
186
190
  if getenv("EMULATE_METAL"): self.device, self.tensor_cores = "METAL", MetalRenderer.tensor_cores
187
- if getenv("EMULATE_HSA"): self.device, self.tensor_cores = "HSA", HIPRenderer.tensor_cores
191
+ if getenv("EMULATE_AMD"): self.device, self.tensor_cores = "AMD", AMDRenderer.tensor_cores
188
192
  if getenv("EMULATE_CUDA"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tensor_cores
189
193
 
190
194
  def render(self, name:str, uops:UOpGraph) -> str:
191
- lops = [(u.uop, u.dtype, [uops.uops.index(v) for v in u.vin], u.arg) for u in uops]
195
+ lops = [(u.op, u.dtype, [uops.uops.index(v) for v in u.src], u.arg) for u in uops]
192
196
  return base64.b64encode(pickle.dumps(lops)).decode()
193
197
 
194
198
  class PythonCompiler(Compiler):
tinygrad/shape/__init__.py
File without changes
tinygrad/shape/shapetracker.py CHANGED
@@ -25,8 +25,11 @@ class ShapeTracker:
25
25
  return ret
26
26
 
27
27
  def invert(self, out_shape:Tuple[sint, ...]) -> Optional[ShapeTracker]:
28
- ret = tuple(v.invert(s) for v,s in zip(self.views[::-1], [x.shape for x in self.views[::-1][1:]]+[out_shape]))
29
- return ShapeTracker(cast(Tuple[View, ...], ret)).reshape(out_shape) if all(x is not None for x in ret) else None
28
+ inverted_views:List[View] = []
29
+ for v,s in zip(self.views[::-1], [x.shape for x in self.views[::-1][1:]]+[out_shape]):
30
+ if (inverted:= v.invert(s)) is None: return None
31
+ inverted_views.append(inverted)
32
+ return ShapeTracker(tuple(inverted_views)).reshape(out_shape)
30
33
 
31
34
  @staticmethod
32
35
  def from_shape(shape:Tuple[sint, ...]): return ShapeTracker((View.create(shape),))
tinygrad/shape/symbolic.py CHANGED
@@ -22,11 +22,9 @@ class Node:
22
22
 
23
23
  @functools.cached_property
24
24
  def key(self) -> str: return self.render(ctx="DEBUG")
25
- @functools.cached_property
26
- def hash(self) -> int: return hash(self.key)
27
25
  def __repr__(self): return self.render(ctx="REPR")
28
26
  def __str__(self): return "<"+self.key+">"
29
- def __hash__(self): return self.hash
27
+ def __hash__(self): return hash(self.key)
30
28
  def __bool__(self): return not (self.max == self.min == 0)
31
29
  def __eq__(self, other:object) -> bool:
32
30
  if not isinstance(other, Node): return NotImplemented
tinygrad/shape/view.py CHANGED
@@ -3,7 +3,7 @@ import functools, operator, itertools, math
3
3
  from dataclasses import dataclass
4
4
  from typing import Tuple, List, Optional, Dict, Set, cast
5
5
  from tinygrad.helpers import prod, all_int, argsort
6
- from tinygrad.shape.symbolic import Node, NumNode, Variable, sint
6
+ from tinygrad.shape.symbolic import Node, NumNode, Variable, sint, sym_infer
7
7
 
8
8
  @functools.lru_cache(maxsize=None)
9
9
  def canonicalize_strides(shape:Tuple[sint, ...], strides:Tuple[sint, ...]) -> Tuple[sint, ...]:
@@ -12,24 +12,26 @@ def canonicalize_strides(shape:Tuple[sint, ...], strides:Tuple[sint, ...]) -> Tu
12
12
  @functools.lru_cache(maxsize=None)
13
13
  def strides_for_shape(shape:Tuple[sint, ...]) -> Tuple[sint, ...]:
14
14
  if not shape: return ()
15
- strides = tuple(itertools.accumulate(reversed(shape[1:]), operator.mul, initial=1))
16
- return canonicalize_strides(shape, strides[::-1])
15
+ strides = tuple(itertools.accumulate(reversed(shape[1:]), operator.mul, initial=1))[::-1]
16
+ return canonicalize_strides(shape, strides)
17
17
 
18
18
  @functools.lru_cache(maxsize=None)
19
19
  def _merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tuple[Tuple[int, int], ...]]=None) -> Tuple[Tuple[int, int, int], ...]:
20
- # merge contiguous subparts or zero strided dims. ret = List[(merged_dims, stride, merged dims w/o zero stride), ...]
21
- if not shape: return tuple()
22
- assert len(shape) == len(strides)
20
+ # merge contiguous sub-parts or zero strided dims. ret = Tuple[(merged_size, stride, merged size w/o zero stride), ...]
21
+ if not shape: return ()
22
+ assert len(shape) == len(strides) and (mask is None or len(shape) == len(mask))
23
23
  ret = [(shape[0], strides[0], shape[0] if strides[0] else 0)]
24
- # wrt merging zero strided dimensions
25
- merging = strides[0] == 0 and (mask[0][1] - mask[0][0] == 1 if mask else shape[0] == 1)
26
- for i, (sh, st) in enumerate(zip(shape[1:], strides[1:]), start=1):
27
- if sh == 1: continue
28
- if merging or ret[-1][1] == sh * st: # mergeable
29
- ret[-1] = (ret[-1][0] * sh, st, (sh if merging else ret[-1][2] * sh) if st else 0)
30
- else: ret.append((sh, st, sh if st else 0)) # begin new
31
- # merging ends with either non-zero strided dim or zero strided dim with mask range > 1
32
- merging = st == 0 and (mask[i][1] - mask[i][0] == 1 if mask else sh == 1)
24
+ # merge this dim to next dim if size is 1
25
+ merging = (mask[0][1] - mask[0][0] == 1) if mask is not None else shape[0] == 1
26
+ for i, (s, st) in enumerate(zip(shape[1:], strides[1:]), start=1):
27
+ last_s, last_st, last_pre_expand_s = ret[-1]
28
+ # always merge 1
29
+ if s == 1: continue
30
+ # merge last dim with this dim if merging or strides matched
31
+ if merging or last_st == s * st: ret[-1] = (last_s * s, st, (s if merging else last_pre_expand_s * s) if st else 0)
32
+ else: ret.append((s, st, s if st else 0))
33
+ # merge this dim to next dim if size is 1
34
+ merging = (mask[i][1] - mask[i][0] == 1) if mask is not None else s == 1
33
35
  return tuple(ret)
34
36
 
35
37
  @functools.lru_cache(maxsize=None)
@@ -96,9 +98,10 @@ class View:
96
98
  @functools.lru_cache(maxsize=None)
97
99
  def create(shape:Tuple[sint, ...], strides:Optional[Tuple[sint, ...]]=None, offset:sint=0, mask:Optional[Tuple[Tuple[sint, sint], ...]]=None):
98
100
  strides = canonicalize_strides(shape, strides) if strides else strides_for_shape(shape)
101
+ # canonicalize 0 in shape
102
+ if 0 in shape: return View(shape, (0,) * len(shape), offset=0, mask=None, contiguous=True)
99
103
  # canonicalize empty mask
100
104
  if mask is not None and all(m == (0,s) for m,s in zip(mask, shape)): mask = None
101
- contiguous = offset == 0 and mask is None and strides == strides_for_shape(shape)
102
105
  # if any dimension has size >1, but is masked such that only one index in the dimension is unmasked
103
106
  # then its stride can also be set to 0, albeit with a corresponding adjustment required to the offset
104
107
  # TODO: assert comparison with LtNode to avoid mis-using symbolic
@@ -107,6 +110,7 @@ class View:
107
110
  strides, offset, mask = (0,) * len(shape), 0, ((0,0),) * len(shape)
108
111
  offset += sum((strides[i] * mask[i][0]) if e else 0 for i, e in enumerate(elim))
109
112
  strides = tuple(0 if e else st for st,e in zip(strides, elim))
113
+ contiguous = offset == 0 and mask is None and strides == strides_for_shape(shape)
110
114
  return View(shape, strides, offset, mask, contiguous)
111
115
 
112
116
  @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
@@ -241,8 +245,7 @@ class View:
241
245
 
242
246
  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
243
247
  def permute(self, axis: Tuple[int, ...]) -> View:
244
- assert all(isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis), f"invalid permute {axis} for {self.shape}"
245
- assert len(set(axis)) == len(axis) and len(axis) == len(self.shape), f"can't permute {self.shape} with {axis}"
248
+ assert sorted(axis) == list(range(len(self.shape))), f"invalid permutation {axis} of len {len(self.shape)}"
246
249
  return View.create(tuple(self.shape[a] for a in axis), tuple(self.strides[a] for a in axis), self.offset,
247
250
  tuple(self.mask[a] for a in axis) if self.mask is not None else None)
248
251
 
@@ -266,7 +269,7 @@ class View:
266
269
  assert 0 in new_shape, f"cannot reshape 0 size to {new_shape}"
267
270
  return View.create(new_shape)
268
271
  # check for the same size
269
- if all_int(self.shape):
272
+ if (self_all_int := all_int(self.shape)):
270
273
  assert all(isinstance(s, (int, Variable)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
271
274
  if prod(self.shape) != prod([s if isinstance(s, int) else cast(Variable,s).val for s in new_shape]):
272
275
  raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")
@@ -276,6 +279,18 @@ class View:
276
279
  # after the asserts, it's okay to check contiguous
277
280
  if self.contiguous: return View.create(new_shape)
278
281
 
282
+ # if it's not contiguous and new shape is symbolic, check if it's directly replaceable
283
+ if self_all_int and not all_int(new_shape):
284
+ if len(self.shape) != len(new_shape): raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
285
+ for si, so in zip(self.shape, new_shape):
286
+ if isinstance(so, int):
287
+ if si != so: raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
288
+ else:
289
+ var_vals = {v: v.unbind()[1] for v in so.vars()}
290
+ if si != sym_infer(so, var_vals): raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
291
+ # all dimensions matched, return the new view directly
292
+ return View(new_shape, self.strides, self.offset, self.mask, self.contiguous)
293
+
279
294
  strides, r_new_shape = [], reversed(new_shape)
280
295
  for merged_dim, new_stride, real_dim in reversed(_merge_dims(self.shape, self.strides, self.mask)):
281
296
  acc = 1