tinygrad 0.10.0__py3-none-any.whl → 0.10.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/devectorizer.py +247 -0
- tinygrad/codegen/expander.py +121 -0
- tinygrad/codegen/kernel.py +141 -201
- tinygrad/codegen/linearize.py +223 -84
- tinygrad/codegen/lowerer.py +60 -42
- tinygrad/codegen/symbolic.py +476 -0
- tinygrad/codegen/transcendental.py +22 -13
- tinygrad/device.py +187 -47
- tinygrad/dtype.py +39 -28
- tinygrad/engine/jit.py +83 -65
- tinygrad/engine/memory.py +4 -5
- tinygrad/engine/multi.py +161 -0
- tinygrad/engine/realize.py +62 -108
- tinygrad/engine/schedule.py +396 -357
- tinygrad/engine/search.py +55 -66
- tinygrad/gradient.py +73 -0
- tinygrad/helpers.py +81 -59
- tinygrad/nn/__init__.py +30 -32
- tinygrad/nn/datasets.py +1 -2
- tinygrad/nn/optim.py +22 -26
- tinygrad/nn/state.py +91 -66
- tinygrad/ops.py +492 -641
- tinygrad/renderer/__init__.py +95 -36
- tinygrad/renderer/cstyle.py +99 -92
- tinygrad/renderer/llvmir.py +83 -34
- tinygrad/renderer/ptx.py +83 -99
- tinygrad/renderer/wgsl.py +95 -0
- tinygrad/runtime/autogen/amd_gpu.py +39507 -12
- tinygrad/runtime/autogen/comgr.py +2 -0
- tinygrad/runtime/autogen/kfd.py +4 -3
- tinygrad/runtime/autogen/kgsl.py +1 -1
- tinygrad/runtime/autogen/libc.py +404 -71
- tinygrad/runtime/autogen/llvm.py +11379 -0
- tinygrad/runtime/autogen/pci.py +1333 -0
- tinygrad/runtime/autogen/vfio.py +891 -0
- tinygrad/runtime/autogen/webgpu.py +6985 -0
- tinygrad/runtime/graph/cuda.py +8 -9
- tinygrad/runtime/graph/hcq.py +84 -79
- tinygrad/runtime/graph/metal.py +40 -43
- tinygrad/runtime/ops_amd.py +498 -334
- tinygrad/runtime/ops_cloud.py +34 -34
- tinygrad/runtime/ops_cpu.py +24 -0
- tinygrad/runtime/ops_cuda.py +30 -27
- tinygrad/runtime/ops_disk.py +62 -63
- tinygrad/runtime/ops_dsp.py +159 -42
- tinygrad/runtime/ops_gpu.py +30 -30
- tinygrad/runtime/ops_hip.py +29 -31
- tinygrad/runtime/ops_llvm.py +48 -41
- tinygrad/runtime/ops_metal.py +149 -113
- tinygrad/runtime/ops_npy.py +2 -2
- tinygrad/runtime/ops_nv.py +238 -273
- tinygrad/runtime/ops_python.py +55 -50
- tinygrad/runtime/ops_qcom.py +129 -157
- tinygrad/runtime/ops_webgpu.py +225 -0
- tinygrad/runtime/support/allocator.py +94 -0
- tinygrad/runtime/support/am/__init__.py +0 -0
- tinygrad/runtime/support/am/amdev.py +396 -0
- tinygrad/runtime/support/am/ip.py +463 -0
- tinygrad/runtime/support/compiler_cuda.py +4 -2
- tinygrad/runtime/support/elf.py +28 -4
- tinygrad/runtime/support/hcq.py +256 -324
- tinygrad/runtime/support/llvm.py +26 -0
- tinygrad/shape/shapetracker.py +85 -53
- tinygrad/shape/view.py +104 -140
- tinygrad/spec.py +155 -0
- tinygrad/tensor.py +835 -527
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
- tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
- tinygrad/viz/index.html +544 -0
- tinygrad/viz/perfetto.html +178 -0
- tinygrad/viz/serve.py +205 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
- tinygrad-0.10.2.dist-info/RECORD +99 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/uopgraph.py +0 -506
- tinygrad/engine/lazy.py +0 -228
- tinygrad/function.py +0 -212
- tinygrad/multi.py +0 -177
- tinygrad/runtime/graph/clang.py +0 -39
- tinygrad/runtime/ops_clang.py +0 -35
- tinygrad-0.10.0.dist-info/RECORD +0 -77
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/runtime/support/llvm.py
ADDED
@@ -0,0 +1,26 @@
+import ctypes, ctypes.util, os, sys, subprocess
+from tinygrad.helpers import DEBUG, OSX, getenv
+
+if sys.platform == 'win32':
+  # Windows llvm distribution doesn't seem to add itself to PATH or anywhere else where it can be easily retrieved from.
+  # winget also doesn't have something like `brew --prefix llvm` so just hardcode default installation path with an option to override
+  LLVM_PATH = getenv('LLVM_PATH', 'C:\\Program Files\\LLVM\\bin\\LLVM-C.dll')
+  if not os.path.exists(LLVM_PATH):
+    raise FileNotFoundError('LLVM not found, you can install it with `winget install LLVM.LLVM` or point at a custom dll with LLVM_PATH')
+elif OSX:
+  # Will raise FileNotFoundError if brew is not installed
+  brew_prefix = subprocess.check_output(['brew', '--prefix', 'llvm']).decode().strip()
+  # `brew --prefix` will return even if formula is not installed
+  if not os.path.exists(brew_prefix):
+    raise FileNotFoundError('LLVM not found, you can install it with `brew install llvm`')
+  LLVM_PATH: str|None = os.path.join(brew_prefix, 'lib', 'libLLVM.dylib')
+else:
+  LLVM_PATH = ctypes.util.find_library('LLVM')
+  # use newer LLVM if possible
+  for ver in reversed(range(14, 19+1)):
+    if LLVM_PATH is not None: break
+    LLVM_PATH = ctypes.util.find_library(f'LLVM-{ver}')
+  if LLVM_PATH is None:
+    raise FileNotFoundError("No LLVM library found on the system. Install it via your distro's package manager and ensure it's findable as 'LLVM'")
+
+if DEBUG>=3: print(f'Using LLVM at {repr(LLVM_PATH)}')
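The new module only resolves a path to the LLVM shared library; the runtime is expected to load it through ctypes elsewhere. A minimal sketch of that consumption, assuming only the standard ctypes API and two well-known LLVM-C entry points (this loading code is illustrative, not the wheel's actual ops_llvm.py):

import ctypes
from tinygrad.runtime.support.llvm import LLVM_PATH

# Load the shared library found above and call into the LLVM-C API.
lib = ctypes.CDLL(LLVM_PATH)
lib.LLVMContextCreate.restype = ctypes.c_void_p      # opaque LLVMContextRef
lib.LLVMContextDispose.argtypes = [ctypes.c_void_p]
ctx = lib.LLVMContextCreate()
lib.LLVMContextDispose(ctx)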
tinygrad/shape/shapetracker.py
CHANGED
@@ -1,30 +1,79 @@
 # ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
 from __future__ import annotations
 from dataclasses import dataclass
-
+import functools
+from typing import Optional, Callable
 from tinygrad.helpers import merge_dicts, getenv
-from tinygrad.shape.view import View, strides_for_shape
+from tinygrad.shape.view import View, strides_for_shape, unravel
 from tinygrad.dtype import dtypes
-from tinygrad.ops import UOp, Ops, graph_rewrite,
+from tinygrad.ops import UOp, Ops, graph_rewrite, Variable, sint, sint_to_uop, Context
+from tinygrad.codegen.symbolic import sym, split_uop, symbolic_flat, uop_given_valid, simplify_valid
+
+def overflow(u: UOp): return u.vmax > dtypes.max(dtypes.int) or u.vmin < dtypes.min(dtypes.int)
+
+# If a node overflow, its srcs need to be checked to see if this overflow is the result of an ALU operation,
+# or that the node simply inherits the dtype from srcs. Upcast is either `Ops.CAST`+`replace` or just `replace`.
+def upcast(u: UOp):
+  srcs = tuple(upcast(_src) for _src in u.src)
+  if u.dtype.scalar() is dtypes.int:
+    dtype = dtypes.int64.vec(u.dtype.count) if u.dtype.count > 1 else dtypes.int64
+    upcasted = u.replace(dtype=dtype, src=tuple([_src.cast(dtype) for _src in srcs]))
+    if overflow(u): return upcasted
+    # Check the original src, new srcs has Ops.CAST whose vmin, vmax change the real bounds
+    # Cast back is required because if the node is in range, siblings would never be upcasted
+    if any((overflow(src) for src in u.src)): return upcasted.cast(u.dtype)
+  return u.replace(src=tuple(srcs))
+
+# pooling op may overflow before folding causing unnecessary upcast
+def folded_upcast(u: UOp):
+  with Context(TRACK_MATCH_STATS=0):
+    return upcast(graph_rewrite(u, sym, {}))
+
+@functools.lru_cache(None)
+def views_to_indexed_uops(views: tuple[View, ...], _idxs:Optional[tuple[UOp, ...]]=None) -> tuple[UOp, UOp]:
+  idx, valid = views[-1].to_indexed_uops(_idxs)
+  for view in reversed(views[0:-1]):
+    view = view.minify()
+    idx, valid = view.to_indexed_uops([sint_to_uop(i) for i in unravel(view.shape, idx)], valid)
+  return idx, valid
+
+@functools.lru_cache(None)
+def views_to_real_strides(views: tuple[View, ...], ignore_valid=False) -> tuple[Optional[sint], ...]:
+  # NOTE: if a stride is not always valid, it will be None
+  if len(views) == 1 and views[-1].mask is None: return views[-1].strides
+  ret: list[Optional[sint]] = [None] * len(views[-1].shape)
+  idx, valid = (graph_rewrite(u, symbolic_flat) for u in views_to_indexed_uops(views))
+  # TODO: always apply these in to_indexed_uops?
+  if (newvalid:=simplify_valid(valid)) is not None: valid = newvalid
+  if (newidx:=uop_given_valid(valid, idx)) is not None: idx = graph_rewrite(newidx, symbolic_flat)
+  for c in split_uop(idx, Ops.ADD):
+    if c.op is Ops.RANGE: ret[c.arg] = 1
+    if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg] = c.src[1].arg
+    if c.op is Ops.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg] = c.src[0].arg
+  used_ranges = [x.arg for x in idx.toposort if x.op is Ops.RANGE]
+  ret = [x if i in used_ranges else 0 for i,x in enumerate(ret)]
+  if not ignore_valid:
+    for masked_axis in [x.arg for x in valid.toposort if x.op is Ops.RANGE]: ret[masked_axis] = None
+  return tuple(ret)
 
 @dataclass(frozen=True, order=True)
 class ShapeTracker:
-  views:
+  views: tuple[View, ...]
 
   def __add__(self, st:ShapeTracker) -> ShapeTracker:
     ret = self
     for v in st.views: ret = ShapeTracker(ret.views + (v,)).simplify() # one view at a time = better simplification
     return ret
 
-  def invert(self, out_shape:
-    inverted_views:
+  def invert(self, out_shape:tuple[sint, ...]) -> Optional[ShapeTracker]:
+    inverted_views:list[View] = []
     for v,s in zip(self.views[::-1], [x.shape for x in self.views[::-1][1:]]+[out_shape]):
       if (inverted:= v.invert(s)) is None: return None
       inverted_views.append(inverted)
     return ShapeTracker(tuple(inverted_views)).reshape(out_shape)
 
   @staticmethod
-  def from_shape(shape:
+  def from_shape(shape:tuple[sint, ...]) -> ShapeTracker: return ShapeTracker((View.create(shape),))
 
   @property
   def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous
@@ -33,65 +82,43 @@ class ShapeTracker:
   def consecutive(self) -> bool: return len(self.views) == 1 and (v:=self.views[0]).mask is None and v.strides == strides_for_shape(v.shape)
 
   @property
-  def shape(self) ->
+  def shape(self) -> tuple[sint, ...]: return self.views[-1].shape
 
   @property
   def size(self) -> int: return self.views[-1].size()
 
-  def reduce(self, axis:
+  def reduce(self, axis:tuple[int, ...]) -> tuple[sint, ...]: return tuple(1 if i in axis else s for i,s in enumerate(self.shape))
 
   def to_uop(self) -> UOp: return UOp(Ops.VIEW, dtypes.void, (), self)
+  def to_indexed_uops(self, _idxs:Optional[list[UOp]|tuple[UOp, ...]]=None) -> tuple[UOp, UOp]:
+    idx, valid = views_to_indexed_uops(self.views, tuple(_idxs) if _idxs is not None else None)
+    return folded_upcast(idx), folded_upcast(valid)
 
-
-    idx, valid = self.views[-1].to_indexed_uops(_idxs)
-    for view in reversed(self.views[0:-1]):
-      view = view.minify()
-      acc, idxs = 1, []
-      for d in reversed(view.shape):
-        idxs.append((idx//acc)%d)
-        acc *= d
-      idx, valid = view.to_indexed_uops(idxs[::-1], valid)
-    return idx, valid
-
+  # upper bound on buffer size required to fit this shapetracker
   def real_size(self) -> int:
     if 0 in self.shape: return 0
-
-
+    view = (v.shrink(v.mask) if (v:=self.views[0]).mask else v)
+    idx, _ = views_to_indexed_uops((view,))
     assert idx.vmax < 1e12, f"real_size broken for {self}"
-    return int(idx.vmax+1)
+    return int(idx.vmax + 1)
 
-  def vars(self) ->
+  def vars(self) -> set[Variable]: return set().union(*[v.vars() for v in self.views])
 
   @property
-  def var_vals(self) ->
+  def var_vals(self) -> dict[Variable, int]: return merge_dicts([dict([v.unbind()]) for v in self.vars()])
 
-  def unbind(self) ->
+  def unbind(self) -> tuple[ShapeTracker, dict[Variable, int]]:
     unbound_views, var_vals = zip(*[v.unbind() for v in self.views])
+    if all(len(x) == 0 for x in var_vals): return self, {}
     return ShapeTracker(tuple(unbound_views)), merge_dicts(var_vals)
 
-
-  def
-    if len(self.views) == 1 and self.views[-1].mask is None: return self.views[-1].strides
-    ret: List[Optional[sint]] = [None] * len(self.shape)
-    idx, valid = (graph_rewrite(u, symbolic_flat) for u in self.to_indexed_uops())
-    # TODO: always apply these in to_indexed_uops?
-    if (newvalid:=simplify_valid(valid)) is not None: valid = newvalid
-    if (newidx:=uop_given_valid(valid, idx)) is not None: idx = graph_rewrite(newidx, symbolic_flat)
-    for c in split_uop(idx, Ops.ADD):
-      if c.op is Ops.RANGE: ret[c.arg[0]] = 1
-      if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg[0]] = c.src[1].arg
-      if c.op is Ops.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg[0]] = c.src[0].arg
-    used_ranges = [x.arg[0] for x in idx.sparents if x.op is Ops.RANGE]
-    ret = [x if i in used_ranges else 0 for i,x in enumerate(ret)]
-    if not ignore_valid:
-      for masked_axis in [x.arg[0] for x in valid.sparents if x.op is Ops.RANGE]: ret[masked_axis] = None
-    return tuple(ret)
-
-  def unit_stride_axes(self, ignore_valid=False) -> List[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]
+  def real_strides(self, ignore_valid=False) -> tuple[Optional[sint], ...]: return views_to_real_strides(self.views, ignore_valid)
+  def unit_stride_axes(self, ignore_valid=False) -> list[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]
 
   def axis_is_masked(self, axis:int) -> bool:
-
-
+    with Context(TRACK_MATCH_STATS=0):
+      _, valid = self.to_indexed_uops()
+      return axis in [x.arg for x in graph_rewrite(valid, symbolic_flat).toposort if x.op is Ops.RANGE]
 
   def simplify(self) -> ShapeTracker:
     if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None:
@@ -100,12 +127,17 @@ class ShapeTracker:
 
   # *** under this line are the movement ops ***
 
-  def pad(self, arg:
-  def shrink(self, arg:
-  def expand(self, new_shape:
-  def permute(self, axis:
-  def
+  def pad(self, arg: tuple[tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].pad(arg), ))
+  def shrink(self, arg: tuple[tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].shrink(arg), ))
+  def expand(self, new_shape: tuple[sint, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].expand(new_shape), ))
+  def permute(self, axis: tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].permute(axis), ))
+  def flip(self, mul: tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].flip(mul), ))
 
-  def reshape(self, new_shape:
+  def reshape(self, new_shape: tuple[sint, ...]) -> ShapeTracker:
     if getenv("MERGE_VIEW", 1) and (new_view := self.views[-1].reshape(new_shape)) is not None: return ShapeTracker(self.views[0:-1] + (new_view,))
     return ShapeTracker(self.views + (View.create(new_shape), ))
+
+  def mop(self, op, arg): return mops[op](self, arg)
+
+mops: dict[Ops, Callable] = {Ops.RESHAPE: ShapeTracker.reshape, Ops.PERMUTE: ShapeTracker.permute, Ops.EXPAND: ShapeTracker.expand,
+                             Ops.SHRINK: ShapeTracker.shrink, Ops.FLIP: ShapeTracker.flip, Ops.PAD: ShapeTracker.pad}