tinygrad-0.10.0-py3-none-any.whl → tinygrad-0.10.2-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (88)
  1. tinygrad/codegen/devectorizer.py +247 -0
  2. tinygrad/codegen/expander.py +121 -0
  3. tinygrad/codegen/kernel.py +141 -201
  4. tinygrad/codegen/linearize.py +223 -84
  5. tinygrad/codegen/lowerer.py +60 -42
  6. tinygrad/codegen/symbolic.py +476 -0
  7. tinygrad/codegen/transcendental.py +22 -13
  8. tinygrad/device.py +187 -47
  9. tinygrad/dtype.py +39 -28
  10. tinygrad/engine/jit.py +83 -65
  11. tinygrad/engine/memory.py +4 -5
  12. tinygrad/engine/multi.py +161 -0
  13. tinygrad/engine/realize.py +62 -108
  14. tinygrad/engine/schedule.py +396 -357
  15. tinygrad/engine/search.py +55 -66
  16. tinygrad/gradient.py +73 -0
  17. tinygrad/helpers.py +81 -59
  18. tinygrad/nn/__init__.py +30 -32
  19. tinygrad/nn/datasets.py +1 -2
  20. tinygrad/nn/optim.py +22 -26
  21. tinygrad/nn/state.py +91 -66
  22. tinygrad/ops.py +492 -641
  23. tinygrad/renderer/__init__.py +95 -36
  24. tinygrad/renderer/cstyle.py +99 -92
  25. tinygrad/renderer/llvmir.py +83 -34
  26. tinygrad/renderer/ptx.py +83 -99
  27. tinygrad/renderer/wgsl.py +95 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  29. tinygrad/runtime/autogen/comgr.py +2 -0
  30. tinygrad/runtime/autogen/kfd.py +4 -3
  31. tinygrad/runtime/autogen/kgsl.py +1 -1
  32. tinygrad/runtime/autogen/libc.py +404 -71
  33. tinygrad/runtime/autogen/llvm.py +11379 -0
  34. tinygrad/runtime/autogen/pci.py +1333 -0
  35. tinygrad/runtime/autogen/vfio.py +891 -0
  36. tinygrad/runtime/autogen/webgpu.py +6985 -0
  37. tinygrad/runtime/graph/cuda.py +8 -9
  38. tinygrad/runtime/graph/hcq.py +84 -79
  39. tinygrad/runtime/graph/metal.py +40 -43
  40. tinygrad/runtime/ops_amd.py +498 -334
  41. tinygrad/runtime/ops_cloud.py +34 -34
  42. tinygrad/runtime/ops_cpu.py +24 -0
  43. tinygrad/runtime/ops_cuda.py +30 -27
  44. tinygrad/runtime/ops_disk.py +62 -63
  45. tinygrad/runtime/ops_dsp.py +159 -42
  46. tinygrad/runtime/ops_gpu.py +30 -30
  47. tinygrad/runtime/ops_hip.py +29 -31
  48. tinygrad/runtime/ops_llvm.py +48 -41
  49. tinygrad/runtime/ops_metal.py +149 -113
  50. tinygrad/runtime/ops_npy.py +2 -2
  51. tinygrad/runtime/ops_nv.py +238 -273
  52. tinygrad/runtime/ops_python.py +55 -50
  53. tinygrad/runtime/ops_qcom.py +129 -157
  54. tinygrad/runtime/ops_webgpu.py +225 -0
  55. tinygrad/runtime/support/allocator.py +94 -0
  56. tinygrad/runtime/support/am/__init__.py +0 -0
  57. tinygrad/runtime/support/am/amdev.py +396 -0
  58. tinygrad/runtime/support/am/ip.py +463 -0
  59. tinygrad/runtime/support/compiler_cuda.py +4 -2
  60. tinygrad/runtime/support/elf.py +28 -4
  61. tinygrad/runtime/support/hcq.py +256 -324
  62. tinygrad/runtime/support/llvm.py +26 -0
  63. tinygrad/shape/shapetracker.py +85 -53
  64. tinygrad/shape/view.py +104 -140
  65. tinygrad/spec.py +155 -0
  66. tinygrad/tensor.py +835 -527
  67. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
  68. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
  69. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
  70. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
  71. tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
  72. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
  73. tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
  74. tinygrad/viz/index.html +544 -0
  75. tinygrad/viz/perfetto.html +178 -0
  76. tinygrad/viz/serve.py +205 -0
  77. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
  78. tinygrad-0.10.2.dist-info/RECORD +99 -0
  79. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
  80. tinygrad/codegen/uopgraph.py +0 -506
  81. tinygrad/engine/lazy.py +0 -228
  82. tinygrad/function.py +0 -212
  83. tinygrad/multi.py +0 -177
  84. tinygrad/runtime/graph/clang.py +0 -39
  85. tinygrad/runtime/ops_clang.py +0 -35
  86. tinygrad-0.10.0.dist-info/RECORD +0 -77
  87. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
  88. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/runtime/support/llvm.py (new file)
@@ -0,0 +1,26 @@
+ import ctypes, ctypes.util, os, sys, subprocess
+ from tinygrad.helpers import DEBUG, OSX, getenv
+
+ if sys.platform == 'win32':
+   # Windows llvm distribution doesn't seem to add itself to PATH or anywhere else where it can be easily retrieved from.
+   # winget also doesn't have something like `brew --prefix llvm` so just hardcode default installation path with an option to override
+   LLVM_PATH = getenv('LLVM_PATH', 'C:\\Program Files\\LLVM\\bin\\LLVM-C.dll')
+   if not os.path.exists(LLVM_PATH):
+     raise FileNotFoundError('LLVM not found, you can install it with `winget install LLVM.LLVM` or point at a custom dll with LLVM_PATH')
+ elif OSX:
+   # Will raise FileNotFoundError if brew is not installed
+   brew_prefix = subprocess.check_output(['brew', '--prefix', 'llvm']).decode().strip()
+   # `brew --prefix` will return even if formula is not installed
+   if not os.path.exists(brew_prefix):
+     raise FileNotFoundError('LLVM not found, you can install it with `brew install llvm`')
+   LLVM_PATH: str|None = os.path.join(brew_prefix, 'lib', 'libLLVM.dylib')
+ else:
+   LLVM_PATH = ctypes.util.find_library('LLVM')
+   # use newer LLVM if possible
+   for ver in reversed(range(14, 19+1)):
+     if LLVM_PATH is not None: break
+     LLVM_PATH = ctypes.util.find_library(f'LLVM-{ver}')
+   if LLVM_PATH is None:
+     raise FileNotFoundError("No LLVM library found on the system. Install it via your distro's package manager and ensure it's findable as 'LLVM'")
+
+ if DEBUG>=3: print(f'Using LLVM at {repr(LLVM_PATH)}')
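The new helper only resolves a path to the LLVM shared library; loading it is left to the caller. A minimal usage sketch (not part of the diff), assuming tinygrad 0.10.2 is installed:

```python
# Minimal sketch: load the library that llvm.py resolved.
# LLVM_PATH is the module-level variable defined in the new file above.
import ctypes
from tinygrad.runtime.support.llvm import LLVM_PATH  # raises FileNotFoundError if no LLVM is found

llvm = ctypes.CDLL(LLVM_PATH)  # an OSError here means the resolved path exists but is not loadable
print(f"loaded LLVM from {LLVM_PATH}")
```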
tinygrad/shape/shapetracker.py
@@ -1,30 +1,79 @@
  # ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
  from __future__ import annotations
  from dataclasses import dataclass
- from typing import Tuple, List, Optional, Dict, Set
+ import functools
+ from typing import Optional, Callable
  from tinygrad.helpers import merge_dicts, getenv
- from tinygrad.shape.view import View, strides_for_shape
+ from tinygrad.shape.view import View, strides_for_shape, unravel
  from tinygrad.dtype import dtypes
- from tinygrad.ops import UOp, Ops, graph_rewrite, split_uop, symbolic_flat, Variable, sint, uop_given_valid, simplify_valid
+ from tinygrad.ops import UOp, Ops, graph_rewrite, Variable, sint, sint_to_uop, Context
+ from tinygrad.codegen.symbolic import sym, split_uop, symbolic_flat, uop_given_valid, simplify_valid
+
+ def overflow(u: UOp): return u.vmax > dtypes.max(dtypes.int) or u.vmin < dtypes.min(dtypes.int)
+
+ # If a node overflow, its srcs need to be checked to see if this overflow is the result of an ALU operation,
+ # or that the node simply inherits the dtype from srcs. Upcast is either `Ops.CAST`+`replace` or just `replace`.
+ def upcast(u: UOp):
+   srcs = tuple(upcast(_src) for _src in u.src)
+   if u.dtype.scalar() is dtypes.int:
+     dtype = dtypes.int64.vec(u.dtype.count) if u.dtype.count > 1 else dtypes.int64
+     upcasted = u.replace(dtype=dtype, src=tuple([_src.cast(dtype) for _src in srcs]))
+     if overflow(u): return upcasted
+     # Check the original src, new srcs has Ops.CAST whose vmin, vmax change the real bounds
+     # Cast back is required because if the node is in range, siblings would never be upcasted
+     if any((overflow(src) for src in u.src)): return upcasted.cast(u.dtype)
+   return u.replace(src=tuple(srcs))
+
+ # pooling op may overflow before folding causing unnecessary upcast
+ def folded_upcast(u: UOp):
+   with Context(TRACK_MATCH_STATS=0):
+     return upcast(graph_rewrite(u, sym, {}))
+
+ @functools.lru_cache(None)
+ def views_to_indexed_uops(views: tuple[View, ...], _idxs:Optional[tuple[UOp, ...]]=None) -> tuple[UOp, UOp]:
+   idx, valid = views[-1].to_indexed_uops(_idxs)
+   for view in reversed(views[0:-1]):
+     view = view.minify()
+     idx, valid = view.to_indexed_uops([sint_to_uop(i) for i in unravel(view.shape, idx)], valid)
+   return idx, valid
+
+ @functools.lru_cache(None)
+ def views_to_real_strides(views: tuple[View, ...], ignore_valid=False) -> tuple[Optional[sint], ...]:
+   # NOTE: if a stride is not always valid, it will be None
+   if len(views) == 1 and views[-1].mask is None: return views[-1].strides
+   ret: list[Optional[sint]] = [None] * len(views[-1].shape)
+   idx, valid = (graph_rewrite(u, symbolic_flat) for u in views_to_indexed_uops(views))
+   # TODO: always apply these in to_indexed_uops?
+   if (newvalid:=simplify_valid(valid)) is not None: valid = newvalid
+   if (newidx:=uop_given_valid(valid, idx)) is not None: idx = graph_rewrite(newidx, symbolic_flat)
+   for c in split_uop(idx, Ops.ADD):
+     if c.op is Ops.RANGE: ret[c.arg] = 1
+     if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg] = c.src[1].arg
+     if c.op is Ops.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg] = c.src[0].arg
+   used_ranges = [x.arg for x in idx.toposort if x.op is Ops.RANGE]
+   ret = [x if i in used_ranges else 0 for i,x in enumerate(ret)]
+   if not ignore_valid:
+     for masked_axis in [x.arg for x in valid.toposort if x.op is Ops.RANGE]: ret[masked_axis] = None
+   return tuple(ret)

  @dataclass(frozen=True, order=True)
  class ShapeTracker:
-   views: Tuple[View, ...]
+   views: tuple[View, ...]

    def __add__(self, st:ShapeTracker) -> ShapeTracker:
      ret = self
      for v in st.views: ret = ShapeTracker(ret.views + (v,)).simplify() # one view at a time = better simplification
      return ret

-   def invert(self, out_shape:Tuple[sint, ...]) -> Optional[ShapeTracker]:
-     inverted_views:List[View] = []
+   def invert(self, out_shape:tuple[sint, ...]) -> Optional[ShapeTracker]:
+     inverted_views:list[View] = []
      for v,s in zip(self.views[::-1], [x.shape for x in self.views[::-1][1:]]+[out_shape]):
        if (inverted:= v.invert(s)) is None: return None
        inverted_views.append(inverted)
      return ShapeTracker(tuple(inverted_views)).reshape(out_shape)

    @staticmethod
-   def from_shape(shape:Tuple[sint, ...]) -> ShapeTracker: return ShapeTracker((View.create(shape),))
+   def from_shape(shape:tuple[sint, ...]) -> ShapeTracker: return ShapeTracker((View.create(shape),))

    @property
    def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous
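The new `views_to_indexed_uops` helper folds a multi-view ShapeTracker into a single index expression by un-flattening the index of each lower view with `unravel` (imported from `tinygrad.shape.view`), replacing the inline `(idx//acc)%d` loop removed later in this diff. A standalone sketch of that index arithmetic on plain ints, for illustration only (the real helper operates on symbolic `UOp`s):

```python
# Illustration only: unravel a flat index into per-axis coordinates, innermost axis first.
# This mirrors the removed (idx//acc)%d loop; it is not tinygrad's implementation.
def unravel_sketch(shape: tuple[int, ...], idx: int) -> list[int]:
  acc, out = 1, []
  for d in reversed(shape):
    out.append((idx // acc) % d)  # coordinate along this axis
    acc *= d                      # running product of trailing dimensions
  return out[::-1]

assert unravel_sketch((2, 3, 4), 13) == [1, 0, 1]   # 13 == 1*12 + 0*4 + 1*1
```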
@@ -33,65 +82,43 @@ class ShapeTracker:
    def consecutive(self) -> bool: return len(self.views) == 1 and (v:=self.views[0]).mask is None and v.strides == strides_for_shape(v.shape)

    @property
-   def shape(self) -> Tuple[sint, ...]: return self.views[-1].shape
+   def shape(self) -> tuple[sint, ...]: return self.views[-1].shape

    @property
    def size(self) -> int: return self.views[-1].size()

-   def reduce(self, axis:Tuple[int, ...]) -> Tuple[sint, ...]: return tuple(1 if i in axis else s for i,s in enumerate(self.shape))
+   def reduce(self, axis:tuple[int, ...]) -> tuple[sint, ...]: return tuple(1 if i in axis else s for i,s in enumerate(self.shape))

    def to_uop(self) -> UOp: return UOp(Ops.VIEW, dtypes.void, (), self)
+   def to_indexed_uops(self, _idxs:Optional[list[UOp]|tuple[UOp, ...]]=None) -> tuple[UOp, UOp]:
+     idx, valid = views_to_indexed_uops(self.views, tuple(_idxs) if _idxs is not None else None)
+     return folded_upcast(idx), folded_upcast(valid)

-   def to_indexed_uops(self, _idxs:Optional[List[UOp]]=None) -> Tuple[UOp, UOp]:
-     idx, valid = self.views[-1].to_indexed_uops(_idxs)
-     for view in reversed(self.views[0:-1]):
-       view = view.minify()
-       acc, idxs = 1, []
-       for d in reversed(view.shape):
-         idxs.append((idx//acc)%d)
-         acc *= d
-       idx, valid = view.to_indexed_uops(idxs[::-1], valid)
-     return idx, valid
-
+   # upper bound on buffer size required to fit this shapetracker
    def real_size(self) -> int:
      if 0 in self.shape: return 0
-     idx, valid = self.to_indexed_uops()
-     if not valid.vmax: return 0
+     view = (v.shrink(v.mask) if (v:=self.views[0]).mask else v)
+     idx, _ = views_to_indexed_uops((view,))
      assert idx.vmax < 1e12, f"real_size broken for {self}"
-     return int(idx.vmax+1)
+     return int(idx.vmax + 1)

-   def vars(self) -> Set[Variable]: return set().union(*[v.vars() for v in self.views])
+   def vars(self) -> set[Variable]: return set().union(*[v.vars() for v in self.views])

    @property
-   def var_vals(self) -> Dict[Variable, int]: return merge_dicts([dict([v.unbind()]) for v in self.vars()])
+   def var_vals(self) -> dict[Variable, int]: return merge_dicts([dict([v.unbind()]) for v in self.vars()])

-   def unbind(self) -> Tuple[ShapeTracker, Dict[Variable, int]]:
+   def unbind(self) -> tuple[ShapeTracker, dict[Variable, int]]:
      unbound_views, var_vals = zip(*[v.unbind() for v in self.views])
+     if all(len(x) == 0 for x in var_vals): return self, {}
      return ShapeTracker(tuple(unbound_views)), merge_dicts(var_vals)

-   # NOTE: if a stride is not always valid, it will be None
-   def real_strides(self, ignore_valid=False) -> Tuple[Optional[sint], ...]:
-     if len(self.views) == 1 and self.views[-1].mask is None: return self.views[-1].strides
-     ret: List[Optional[sint]] = [None] * len(self.shape)
-     idx, valid = (graph_rewrite(u, symbolic_flat) for u in self.to_indexed_uops())
-     # TODO: always apply these in to_indexed_uops?
-     if (newvalid:=simplify_valid(valid)) is not None: valid = newvalid
-     if (newidx:=uop_given_valid(valid, idx)) is not None: idx = graph_rewrite(newidx, symbolic_flat)
-     for c in split_uop(idx, Ops.ADD):
-       if c.op is Ops.RANGE: ret[c.arg[0]] = 1
-       if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg[0]] = c.src[1].arg
-       if c.op is Ops.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg[0]] = c.src[0].arg
-     used_ranges = [x.arg[0] for x in idx.sparents if x.op is Ops.RANGE]
-     ret = [x if i in used_ranges else 0 for i,x in enumerate(ret)]
-     if not ignore_valid:
-       for masked_axis in [x.arg[0] for x in valid.sparents if x.op is Ops.RANGE]: ret[masked_axis] = None
-     return tuple(ret)
-
-   def unit_stride_axes(self, ignore_valid=False) -> List[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]
+   def real_strides(self, ignore_valid=False) -> tuple[Optional[sint], ...]: return views_to_real_strides(self.views, ignore_valid)
+   def unit_stride_axes(self, ignore_valid=False) -> list[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]

    def axis_is_masked(self, axis:int) -> bool:
-     _, valid = self.to_indexed_uops()
-     return axis in [x.arg[0] for x in graph_rewrite(valid, symbolic_flat).sparents if x.op is Ops.RANGE]
+     with Context(TRACK_MATCH_STATS=0):
+       _, valid = self.to_indexed_uops()
+       return axis in [x.arg for x in graph_rewrite(valid, symbolic_flat).toposort if x.op is Ops.RANGE]

    def simplify(self) -> ShapeTracker:
      if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None:
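`real_strides` and `unit_stride_axes` keep their semantics after this refactor; they now just delegate to the cached module-level `views_to_real_strides`. A small usage sketch, assuming tinygrad 0.10.2 is installed:

```python
# Sketch: a single unmasked view reports its strides directly.
from tinygrad.shape.shapetracker import ShapeTracker

st = ShapeTracker.from_shape((2, 3)).permute((1, 0))  # shape (3, 2), strides (1, 3)
print(st.real_strides())       # expected (1, 3)
print(st.unit_stride_axes())   # expected [0], the axes whose stride is exactly 1
```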
@@ -100,12 +127,17 @@ class ShapeTracker:

    # *** under this line are the movement ops ***

-   def pad(self, arg: Tuple[Tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].pad(arg), ))
-   def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].shrink(arg), ))
-   def expand(self, new_shape: Tuple[sint, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].expand(new_shape), ))
-   def permute(self, axis: Tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].permute(axis), ))
-   def stride(self, mul: Tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].stride(mul), ))
+   def pad(self, arg: tuple[tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].pad(arg), ))
+   def shrink(self, arg: tuple[tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].shrink(arg), ))
+   def expand(self, new_shape: tuple[sint, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].expand(new_shape), ))
+   def permute(self, axis: tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].permute(axis), ))
+   def flip(self, mul: tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].flip(mul), ))

-   def reshape(self, new_shape: Tuple[sint, ...]) -> ShapeTracker:
+   def reshape(self, new_shape: tuple[sint, ...]) -> ShapeTracker:
      if getenv("MERGE_VIEW", 1) and (new_view := self.views[-1].reshape(new_shape)) is not None: return ShapeTracker(self.views[0:-1] + (new_view,))
      return ShapeTracker(self.views + (View.create(new_shape), ))
+
+   def mop(self, op, arg): return mops[op](self, arg)
+
+ mops: dict[Ops, Callable] = {Ops.RESHAPE: ShapeTracker.reshape, Ops.PERMUTE: ShapeTracker.permute, Ops.EXPAND: ShapeTracker.expand,
+                              Ops.SHRINK: ShapeTracker.shrink, Ops.FLIP: ShapeTracker.flip, Ops.PAD: ShapeTracker.pad}
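The new `mop` method and module-level `mops` table give callers a single dispatch point for movement ops keyed by `Ops`. A short usage sketch, assuming tinygrad 0.10.2 is installed:

```python
# Sketch: mop(op, arg) looks the movement op up in `mops` and applies the bound method.
from tinygrad.ops import Ops
from tinygrad.shape.shapetracker import ShapeTracker

st = ShapeTracker.from_shape((2, 3, 4))
# dispatching through the table is equivalent to calling the method directly
assert st.mop(Ops.PERMUTE, (2, 0, 1)) == st.permute((2, 0, 1))
```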