tinygrad 0.10.0-py3-none-any.whl → 0.10.1-py3-none-any.whl

This diff shows the content of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (72)
  1. tinygrad/codegen/kernel.py +114 -172
  2. tinygrad/codegen/linearize.py +211 -81
  3. tinygrad/codegen/lowerer.py +30 -35
  4. tinygrad/codegen/{uopgraph.py → rewriter.py} +69 -59
  5. tinygrad/codegen/transcendental.py +12 -13
  6. tinygrad/device.py +170 -47
  7. tinygrad/dtype.py +28 -26
  8. tinygrad/engine/jit.py +80 -63
  9. tinygrad/engine/memory.py +4 -5
  10. tinygrad/engine/multi.py +162 -0
  11. tinygrad/engine/realize.py +58 -107
  12. tinygrad/engine/schedule.py +381 -314
  13. tinygrad/engine/search.py +40 -44
  14. tinygrad/gradient.py +70 -0
  15. tinygrad/helpers.py +77 -58
  16. tinygrad/nn/__init__.py +30 -32
  17. tinygrad/nn/datasets.py +1 -2
  18. tinygrad/nn/optim.py +22 -26
  19. tinygrad/nn/state.py +89 -64
  20. tinygrad/ops.py +562 -446
  21. tinygrad/renderer/__init__.py +79 -36
  22. tinygrad/renderer/cstyle.py +70 -84
  23. tinygrad/renderer/llvmir.py +32 -20
  24. tinygrad/renderer/ptx.py +79 -99
  25. tinygrad/renderer/wgsl.py +87 -0
  26. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  27. tinygrad/runtime/autogen/comgr.py +2 -0
  28. tinygrad/runtime/autogen/kfd.py +4 -3
  29. tinygrad/runtime/autogen/kgsl.py +1 -1
  30. tinygrad/runtime/autogen/libpciaccess.py +2023 -0
  31. tinygrad/runtime/autogen/llvm.py +11379 -0
  32. tinygrad/runtime/autogen/vfio.py +891 -0
  33. tinygrad/runtime/graph/cuda.py +8 -9
  34. tinygrad/runtime/graph/hcq.py +84 -79
  35. tinygrad/runtime/graph/metal.py +19 -21
  36. tinygrad/runtime/ops_amd.py +488 -327
  37. tinygrad/runtime/ops_clang.py +15 -28
  38. tinygrad/runtime/ops_cloud.py +34 -34
  39. tinygrad/runtime/ops_cuda.py +30 -27
  40. tinygrad/runtime/ops_disk.py +62 -63
  41. tinygrad/runtime/ops_dsp.py +129 -38
  42. tinygrad/runtime/ops_gpu.py +30 -30
  43. tinygrad/runtime/ops_hip.py +29 -31
  44. tinygrad/runtime/ops_llvm.py +45 -40
  45. tinygrad/runtime/ops_metal.py +93 -73
  46. tinygrad/runtime/ops_npy.py +2 -2
  47. tinygrad/runtime/ops_nv.py +232 -270
  48. tinygrad/runtime/ops_python.py +51 -46
  49. tinygrad/runtime/ops_qcom.py +129 -157
  50. tinygrad/runtime/ops_webgpu.py +63 -0
  51. tinygrad/runtime/support/allocator.py +94 -0
  52. tinygrad/runtime/support/am/__init__.py +0 -0
  53. tinygrad/runtime/support/am/amdev.py +384 -0
  54. tinygrad/runtime/support/am/ip.py +463 -0
  55. tinygrad/runtime/support/compiler_cuda.py +4 -2
  56. tinygrad/runtime/support/elf.py +26 -4
  57. tinygrad/runtime/support/hcq.py +254 -324
  58. tinygrad/runtime/support/llvm.py +32 -0
  59. tinygrad/shape/shapetracker.py +84 -53
  60. tinygrad/shape/view.py +103 -138
  61. tinygrad/spec.py +154 -0
  62. tinygrad/tensor.py +744 -496
  63. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/METADATA +32 -21
  64. tinygrad-0.10.1.dist-info/RECORD +86 -0
  65. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/WHEEL +1 -1
  66. tinygrad/engine/lazy.py +0 -228
  67. tinygrad/function.py +0 -212
  68. tinygrad/multi.py +0 -177
  69. tinygrad/runtime/graph/clang.py +0 -39
  70. tinygrad-0.10.0.dist-info/RECORD +0 -77
  71. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/LICENSE +0 -0
  72. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,32 @@
+ import ctypes, ctypes.util, os, sys, subprocess
+ from tinygrad.helpers import DEBUG, OSX, getenv
+
+ if sys.platform == 'win32':
+   # Windows llvm distribution doesn't seem to add itself to PATH or anywhere else where it can be easily retrieved from.
+   # winget also doesn't have something like `brew --prefix llvm` so just hardcode default installation path with an option to override
+   LLVM_PATH = getenv('LLVM_PATH', 'C:\\Program Files\\LLVM\\bin\\LLVM-C.dll')
+   if not os.path.exists(LLVM_PATH):
+     raise RuntimeError('LLVM not found, you can install it with `winget install LLVM.LLVM` or point at a custom dll with LLVM_PATH')
+ elif OSX and 'tinygrad.runtime.ops_metal' in sys.modules:
+   # Opening METAL after LLVM doesn't fail because ctypes.CDLL opens with RTLD_LOCAL but MTLCompiler opens its own llvm with RTLD_GLOBAL
+   # This means that MTLCompiler's llvm will create its own instances of global state because RTLD_LOCAL doesn't export symbols, but if the RTLD_GLOBAL
+   # library is loaded first then the RTLD_LOCAL library will just use its symbols. On linux there is RTLD_DEEPBIND to prevent that, but on macos there
+   # doesn't seem to be anything we can do.
+   LLVM_PATH = ctypes.util.find_library('tinyllvm')
+   if LLVM_PATH is None:
+     raise RuntimeError("LLVM can't be opened in the same process with metal. You can install llvm distribution which supports that via `brew install uuuvn/tinygrad/tinyllvm`") # noqa: E501
+ elif OSX:
+   brew_prefix = subprocess.check_output(['brew', '--prefix', 'llvm']).decode().strip()
+   # `brew --prefix` will return even if formula is not installed
+   if not os.path.exists(brew_prefix):
+     raise RuntimeError('LLVM not found, you can install it with `brew install llvm`')
+   LLVM_PATH = os.path.join(brew_prefix, 'lib', 'libLLVM.dylib')
+ else:
+   LLVM_PATH = ctypes.util.find_library('LLVM')
+   for ver in range(14, 19+1):
+     if LLVM_PATH is not None: break
+     LLVM_PATH = ctypes.util.find_library(f'LLVM-{ver}')
+   if LLVM_PATH is None:
+     raise RuntimeError("No LLVM library found on the system. Install it via your distro's package manager and ensure it's findable as 'LLVM'")
+
+ if DEBUG>=2: print(f'Using LLVM at {repr(LLVM_PATH)}')
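The new tinygrad/runtime/support/llvm.py only resolves LLVM_PATH; actually loading the library is done elsewhere (presumably by the generated bindings in tinygrad/runtime/autogen/llvm.py listed above). A minimal sketch of how a caller might consume the resolved path, assuming plain ctypes loading:

    # illustrative sketch, not part of the diff
    import ctypes
    from tinygrad.runtime.support.llvm import LLVM_PATH

    # as the macOS comment above notes, ctypes.CDLL opens the library with RTLD_LOCAL
    llvm = ctypes.CDLL(LLVM_PATH)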
@@ -1,30 +1,79 @@
  # ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
  from __future__ import annotations
  from dataclasses import dataclass
- from typing import Tuple, List, Optional, Dict, Set
+ import functools
+ from typing import Optional, Callable
  from tinygrad.helpers import merge_dicts, getenv
- from tinygrad.shape.view import View, strides_for_shape
+ from tinygrad.shape.view import View, strides_for_shape, unravel
  from tinygrad.dtype import dtypes
- from tinygrad.ops import UOp, Ops, graph_rewrite, split_uop, symbolic_flat, Variable, sint, uop_given_valid, simplify_valid
+ from tinygrad.ops import UOp, Ops, graph_rewrite, split_uop, symbolic_flat, Variable, sint, uop_given_valid, simplify_valid, sint_to_uop, Context
+ from tinygrad.codegen.rewriter import sym
+
+ def overflow(u: UOp): return u.vmax > dtypes.max(dtypes.int) or u.vmin < dtypes.min(dtypes.int)
+
+ # If a node overflows, its srcs need to be checked to see if this overflow is the result of an ALU operation,
+ # or whether the node simply inherits the dtype from srcs. Upcast is either `Ops.CAST`+`replace` or just `replace`.
+ def upcast(u: UOp):
+   srcs = tuple(upcast(_src) for _src in u.src)
+   if u.dtype.scalar() is dtypes.int:
+     dtype = dtypes.int64.vec(u.dtype.count) if u.dtype.count > 1 else dtypes.int64
+     upcasted = u.replace(dtype=dtype, src=tuple([_src.cast(dtype) for _src in srcs]))
+     if overflow(u): return upcasted
+     # Check the original src, new srcs have Ops.CAST whose vmin, vmax change the real bounds
+     # Cast back is required because if the node is in range, siblings would never be upcasted
+     if any((overflow(src) for src in u.src)): return upcasted.cast(u.dtype)
+   return u.replace(src=tuple(srcs))
+
+ # pooling op may overflow before folding, causing unnecessary upcast
+ def folded_upcast(u: UOp):
+   with Context(TRACK_MATCH_STATS=0):
+     return upcast(graph_rewrite(u, sym, {}))
+
+ @functools.lru_cache(None)
+ def views_to_indexed_uops(views: tuple[View, ...], _idxs:Optional[tuple[UOp, ...]]=None) -> tuple[UOp, UOp]:
+   idx, valid = views[-1].to_indexed_uops(_idxs)
+   for view in reversed(views[0:-1]):
+     view = view.minify()
+     idx, valid = view.to_indexed_uops([sint_to_uop(i) for i in unravel(view.shape, idx)], valid)
+   return idx, valid
+
+ @functools.lru_cache(None)
+ def views_to_real_strides(views: tuple[View, ...], ignore_valid=False) -> tuple[Optional[sint], ...]:
+   # NOTE: if a stride is not always valid, it will be None
+   if len(views) == 1 and views[-1].mask is None: return views[-1].strides
+   ret: list[Optional[sint]] = [None] * len(views[-1].shape)
+   idx, valid = (graph_rewrite(u, symbolic_flat) for u in views_to_indexed_uops(views))
+   # TODO: always apply these in to_indexed_uops?
+   if (newvalid:=simplify_valid(valid)) is not None: valid = newvalid
+   if (newidx:=uop_given_valid(valid, idx)) is not None: idx = graph_rewrite(newidx, symbolic_flat)
+   for c in split_uop(idx, Ops.ADD):
+     if c.op is Ops.RANGE: ret[c.arg] = 1
+     if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg] = c.src[1].arg
+     if c.op is Ops.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg] = c.src[0].arg
+   used_ranges = [x.arg for x in idx.toposort if x.op is Ops.RANGE]
+   ret = [x if i in used_ranges else 0 for i,x in enumerate(ret)]
+   if not ignore_valid:
+     for masked_axis in [x.arg for x in valid.toposort if x.op is Ops.RANGE]: ret[masked_axis] = None
+   return tuple(ret)

  @dataclass(frozen=True, order=True)
  class ShapeTracker:
-   views: Tuple[View, ...]
+   views: tuple[View, ...]

    def __add__(self, st:ShapeTracker) -> ShapeTracker:
      ret = self
      for v in st.views: ret = ShapeTracker(ret.views + (v,)).simplify() # one view at a time = better simplification
      return ret

-   def invert(self, out_shape:Tuple[sint, ...]) -> Optional[ShapeTracker]:
-     inverted_views:List[View] = []
+   def invert(self, out_shape:tuple[sint, ...]) -> Optional[ShapeTracker]:
+     inverted_views:list[View] = []
      for v,s in zip(self.views[::-1], [x.shape for x in self.views[::-1][1:]]+[out_shape]):
        if (inverted:= v.invert(s)) is None: return None
        inverted_views.append(inverted)
      return ShapeTracker(tuple(inverted_views)).reshape(out_shape)

    @staticmethod
-   def from_shape(shape:Tuple[sint, ...]) -> ShapeTracker: return ShapeTracker((View.create(shape),))
+   def from_shape(shape:tuple[sint, ...]) -> ShapeTracker: return ShapeTracker((View.create(shape),))

    @property
    def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous
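`views_to_indexed_uops` now delegates the per-view index decomposition to the new `unravel` helper imported from tinygrad.shape.view. A minimal sketch of that decomposition using plain ints instead of UOps (the name `unravel_int` is made up for illustration); it mirrors the (idx//acc)%d loop removed further down in this file:

    # illustrative sketch, not part of the diff
    def unravel_int(shape: tuple[int, ...], idx: int) -> list[int]:
      # peel off the innermost dimension first, exactly like the old inline loop did
      acc, idxs = 1, []
      for d in reversed(shape):
        idxs.append((idx // acc) % d)
        acc *= d
      return idxs[::-1]

    assert unravel_int((3, 4), 7) == [1, 3]  # flat index 7 in a 3x4 view -> row 1, col 3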
@@ -33,65 +82,42 @@ class ShapeTracker:
    def consecutive(self) -> bool: return len(self.views) == 1 and (v:=self.views[0]).mask is None and v.strides == strides_for_shape(v.shape)

    @property
-   def shape(self) -> Tuple[sint, ...]: return self.views[-1].shape
+   def shape(self) -> tuple[sint, ...]: return self.views[-1].shape

    @property
    def size(self) -> int: return self.views[-1].size()

-   def reduce(self, axis:Tuple[int, ...]) -> Tuple[sint, ...]: return tuple(1 if i in axis else s for i,s in enumerate(self.shape))
+   def reduce(self, axis:tuple[int, ...]) -> tuple[sint, ...]: return tuple(1 if i in axis else s for i,s in enumerate(self.shape))

    def to_uop(self) -> UOp: return UOp(Ops.VIEW, dtypes.void, (), self)
+   def to_indexed_uops(self, _idxs:Optional[list[UOp]|tuple[UOp, ...]]=None) -> tuple[UOp, UOp]:
+     idx, valid = views_to_indexed_uops(self.views, tuple(_idxs) if _idxs is not None else None)
+     return folded_upcast(idx), folded_upcast(valid)

-   def to_indexed_uops(self, _idxs:Optional[List[UOp]]=None) -> Tuple[UOp, UOp]:
-     idx, valid = self.views[-1].to_indexed_uops(_idxs)
-     for view in reversed(self.views[0:-1]):
-       view = view.minify()
-       acc, idxs = 1, []
-       for d in reversed(view.shape):
-         idxs.append((idx//acc)%d)
-         acc *= d
-       idx, valid = view.to_indexed_uops(idxs[::-1], valid)
-     return idx, valid
-
+   # upper bound on buffer size required to fit this shapetracker
    def real_size(self) -> int:
      if 0 in self.shape: return 0
-     idx, valid = self.to_indexed_uops()
-     if not valid.vmax: return 0
+     view = (v.shrink(v.mask) if (v:=self.views[0]).mask else v)
+     idx, _ = views_to_indexed_uops((view,))
      assert idx.vmax < 1e12, f"real_size broken for {self}"
-     return int(idx.vmax+1)
+     return int(idx.vmax + 1)

-   def vars(self) -> Set[Variable]: return set().union(*[v.vars() for v in self.views])
+   def vars(self) -> set[Variable]: return set().union(*[v.vars() for v in self.views])

    @property
-   def var_vals(self) -> Dict[Variable, int]: return merge_dicts([dict([v.unbind()]) for v in self.vars()])
+   def var_vals(self) -> dict[Variable, int]: return merge_dicts([dict([v.unbind()]) for v in self.vars()])

-   def unbind(self) -> Tuple[ShapeTracker, Dict[Variable, int]]:
+   def unbind(self) -> tuple[ShapeTracker, dict[Variable, int]]:
      unbound_views, var_vals = zip(*[v.unbind() for v in self.views])
      return ShapeTracker(tuple(unbound_views)), merge_dicts(var_vals)

-   # NOTE: if a stride is not always valid, it will be None
-   def real_strides(self, ignore_valid=False) -> Tuple[Optional[sint], ...]:
-     if len(self.views) == 1 and self.views[-1].mask is None: return self.views[-1].strides
-     ret: List[Optional[sint]] = [None] * len(self.shape)
-     idx, valid = (graph_rewrite(u, symbolic_flat) for u in self.to_indexed_uops())
-     # TODO: always apply these in to_indexed_uops?
-     if (newvalid:=simplify_valid(valid)) is not None: valid = newvalid
-     if (newidx:=uop_given_valid(valid, idx)) is not None: idx = graph_rewrite(newidx, symbolic_flat)
-     for c in split_uop(idx, Ops.ADD):
-       if c.op is Ops.RANGE: ret[c.arg[0]] = 1
-       if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg[0]] = c.src[1].arg
-       if c.op is Ops.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg[0]] = c.src[0].arg
-     used_ranges = [x.arg[0] for x in idx.sparents if x.op is Ops.RANGE]
-     ret = [x if i in used_ranges else 0 for i,x in enumerate(ret)]
-     if not ignore_valid:
-       for masked_axis in [x.arg[0] for x in valid.sparents if x.op is Ops.RANGE]: ret[masked_axis] = None
-     return tuple(ret)
-
-   def unit_stride_axes(self, ignore_valid=False) -> List[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]
+   def real_strides(self, ignore_valid=False) -> tuple[Optional[sint], ...]: return views_to_real_strides(self.views, ignore_valid)
+   def unit_stride_axes(self, ignore_valid=False) -> list[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]

    def axis_is_masked(self, axis:int) -> bool:
-     _, valid = self.to_indexed_uops()
-     return axis in [x.arg[0] for x in graph_rewrite(valid, symbolic_flat).sparents if x.op is Ops.RANGE]
+     with Context(TRACK_MATCH_STATS=0):
+       _, valid = self.to_indexed_uops()
+       return axis in [x.arg for x in graph_rewrite(valid, symbolic_flat).toposort if x.op is Ops.RANGE]

    def simplify(self) -> ShapeTracker:
      if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None:
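`real_strides` is now a thin wrapper over the cached `views_to_real_strides`, which recovers a per-axis stride (or None) by splitting the simplified index expression into RANGE and RANGE*CONST terms. A small usage sketch with made-up shapes, assuming a plain single-view tracker:

    # illustrative sketch, not part of the diff
    from tinygrad.shape.shapetracker import ShapeTracker

    st = ShapeTracker.from_shape((4, 8)).permute((1, 0))  # shape (8, 4), strides swapped
    print(st.real_strides())      # expected (1, 8)
    print(st.unit_stride_axes())  # expected [0], the axis that steps by one element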
@@ -100,12 +126,17 @@ class ShapeTracker:

    # *** under this line are the movement ops ***

-   def pad(self, arg: Tuple[Tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].pad(arg), ))
-   def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].shrink(arg), ))
-   def expand(self, new_shape: Tuple[sint, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].expand(new_shape), ))
-   def permute(self, axis: Tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].permute(axis), ))
-   def stride(self, mul: Tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].stride(mul), ))
+   def pad(self, arg: tuple[tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].pad(arg), ))
+   def shrink(self, arg: tuple[tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].shrink(arg), ))
+   def expand(self, new_shape: tuple[sint, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].expand(new_shape), ))
+   def permute(self, axis: tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].permute(axis), ))
+   def flip(self, mul: tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].flip(mul), ))

-   def reshape(self, new_shape: Tuple[sint, ...]) -> ShapeTracker:
+   def reshape(self, new_shape: tuple[sint, ...]) -> ShapeTracker:
      if getenv("MERGE_VIEW", 1) and (new_view := self.views[-1].reshape(new_shape)) is not None: return ShapeTracker(self.views[0:-1] + (new_view,))
      return ShapeTracker(self.views + (View.create(new_shape), ))
+
+   def mop(self, op, arg): return mops[op](self, arg)
+
+ mops: dict[Ops, Callable] = {Ops.RESHAPE: ShapeTracker.reshape, Ops.PERMUTE: ShapeTracker.permute, Ops.EXPAND: ShapeTracker.expand,
+                              Ops.SHRINK: ShapeTracker.shrink, Ops.FLIP: ShapeTracker.flip, Ops.PAD: ShapeTracker.pad}
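The new module-level `mops` table lets callers apply a movement op by its Ops enum instead of calling the corresponding method directly. A minimal sketch of the dispatch (shape and arguments are made up):

    # illustrative sketch, not part of the diff
    from tinygrad.ops import Ops
    from tinygrad.shape.shapetracker import ShapeTracker

    st = ShapeTracker.from_shape((2, 3, 4))
    # mop(op, arg) just looks up the bound method in mops and calls it
    assert st.mop(Ops.PERMUTE, (2, 0, 1)) == st.permute((2, 0, 1))
    assert st.mop(Ops.RESHAPE, (6, 4)) == st.reshape((6, 4))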