tinygrad 0.10.0__py3-none-any.whl → 0.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/kernel.py +114 -172
- tinygrad/codegen/linearize.py +211 -81
- tinygrad/codegen/lowerer.py +30 -35
- tinygrad/codegen/{uopgraph.py → rewriter.py} +69 -59
- tinygrad/codegen/transcendental.py +12 -13
- tinygrad/device.py +170 -47
- tinygrad/dtype.py +28 -26
- tinygrad/engine/jit.py +80 -63
- tinygrad/engine/memory.py +4 -5
- tinygrad/engine/multi.py +162 -0
- tinygrad/engine/realize.py +58 -107
- tinygrad/engine/schedule.py +381 -314
- tinygrad/engine/search.py +40 -44
- tinygrad/gradient.py +70 -0
- tinygrad/helpers.py +77 -58
- tinygrad/nn/__init__.py +30 -32
- tinygrad/nn/datasets.py +1 -2
- tinygrad/nn/optim.py +22 -26
- tinygrad/nn/state.py +89 -64
- tinygrad/ops.py +562 -446
- tinygrad/renderer/__init__.py +79 -36
- tinygrad/renderer/cstyle.py +70 -84
- tinygrad/renderer/llvmir.py +32 -20
- tinygrad/renderer/ptx.py +79 -99
- tinygrad/renderer/wgsl.py +87 -0
- tinygrad/runtime/autogen/amd_gpu.py +39507 -12
- tinygrad/runtime/autogen/comgr.py +2 -0
- tinygrad/runtime/autogen/kfd.py +4 -3
- tinygrad/runtime/autogen/kgsl.py +1 -1
- tinygrad/runtime/autogen/libpciaccess.py +2023 -0
- tinygrad/runtime/autogen/llvm.py +11379 -0
- tinygrad/runtime/autogen/vfio.py +891 -0
- tinygrad/runtime/graph/cuda.py +8 -9
- tinygrad/runtime/graph/hcq.py +84 -79
- tinygrad/runtime/graph/metal.py +19 -21
- tinygrad/runtime/ops_amd.py +488 -327
- tinygrad/runtime/ops_clang.py +15 -28
- tinygrad/runtime/ops_cloud.py +34 -34
- tinygrad/runtime/ops_cuda.py +30 -27
- tinygrad/runtime/ops_disk.py +62 -63
- tinygrad/runtime/ops_dsp.py +129 -38
- tinygrad/runtime/ops_gpu.py +30 -30
- tinygrad/runtime/ops_hip.py +29 -31
- tinygrad/runtime/ops_llvm.py +45 -40
- tinygrad/runtime/ops_metal.py +93 -73
- tinygrad/runtime/ops_npy.py +2 -2
- tinygrad/runtime/ops_nv.py +232 -270
- tinygrad/runtime/ops_python.py +51 -46
- tinygrad/runtime/ops_qcom.py +129 -157
- tinygrad/runtime/ops_webgpu.py +63 -0
- tinygrad/runtime/support/allocator.py +94 -0
- tinygrad/runtime/support/am/__init__.py +0 -0
- tinygrad/runtime/support/am/amdev.py +384 -0
- tinygrad/runtime/support/am/ip.py +463 -0
- tinygrad/runtime/support/compiler_cuda.py +4 -2
- tinygrad/runtime/support/elf.py +26 -4
- tinygrad/runtime/support/hcq.py +254 -324
- tinygrad/runtime/support/llvm.py +32 -0
- tinygrad/shape/shapetracker.py +84 -53
- tinygrad/shape/view.py +103 -138
- tinygrad/spec.py +154 -0
- tinygrad/tensor.py +744 -496
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/METADATA +32 -21
- tinygrad-0.10.1.dist-info/RECORD +86 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/WHEEL +1 -1
- tinygrad/engine/lazy.py +0 -228
- tinygrad/function.py +0 -212
- tinygrad/multi.py +0 -177
- tinygrad/runtime/graph/clang.py +0 -39
- tinygrad-0.10.0.dist-info/RECORD +0 -77
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,32 @@
|
|
1
|
+
import ctypes, ctypes.util, os, sys, subprocess
|
2
|
+
from tinygrad.helpers import DEBUG, OSX, getenv
|
3
|
+
|
4
|
+
if sys.platform == 'win32':
|
5
|
+
# Windows llvm distribution doesn't seem to add itself to PATH or anywhere else where it can be easily retrieved from.
|
6
|
+
# winget also doesn't have something like `brew --prefix llvm` so just hardcode default installation path with an option to override
|
7
|
+
LLVM_PATH = getenv('LLVM_PATH', 'C:\\Program Files\\LLVM\\bin\\LLVM-C.dll')
|
8
|
+
if not os.path.exists(LLVM_PATH):
|
9
|
+
raise RuntimeError('LLVM not found, you can install it with `winget install LLVM.LLVM` or point at a custom dll with LLVM_PATH')
|
10
|
+
elif OSX and 'tinygrad.runtime.ops_metal' in sys.modules:
|
11
|
+
# Opening METAL after LLVM doesn't fail because ctypes.CDLL opens with RTLD_LOCAL but MTLCompiler opens it's own llvm with RTLD_GLOBAL
|
12
|
+
# This means that MTLCompiler's llvm will create it's own instances of global state because RTLD_LOCAL doesn't export symbols, but if RTLD_GLOBAL
|
13
|
+
# library is loaded first then RTLD_LOCAL library will just use it's symbols. On linux there is RTLD_DEEPBIND to prevent that, but on macos there
|
14
|
+
# doesn't seem to be anything we can do.
|
15
|
+
LLVM_PATH = ctypes.util.find_library('tinyllvm')
|
16
|
+
if LLVM_PATH is None:
|
17
|
+
raise RuntimeError("LLVM can't be opened in the same process with metal. You can install llvm distribution which supports that via `brew install uuuvn/tinygrad/tinyllvm`") # noqa: E501
|
18
|
+
elif OSX:
|
19
|
+
brew_prefix = subprocess.check_output(['brew', '--prefix', 'llvm']).decode().strip()
|
20
|
+
# `brew --prefix` will return even if formula is not installed
|
21
|
+
if not os.path.exists(brew_prefix):
|
22
|
+
raise RuntimeError('LLVM not found, you can install it with `brew install llvm`')
|
23
|
+
LLVM_PATH = os.path.join(brew_prefix, 'lib', 'libLLVM.dylib')
|
24
|
+
else:
|
25
|
+
LLVM_PATH = ctypes.util.find_library('LLVM')
|
26
|
+
for ver in range(14, 19+1):
|
27
|
+
if LLVM_PATH is not None: break
|
28
|
+
LLVM_PATH = ctypes.util.find_library(f'LLVM-{ver}')
|
29
|
+
if LLVM_PATH is None:
|
30
|
+
raise RuntimeError("No LLVM library found on the system. Install it via your distro's package manager and ensure it's findable as 'LLVM'")
|
31
|
+
|
32
|
+
if DEBUG>=2: print(f'Using LLVM at {repr(LLVM_PATH)}')
|
tinygrad/shape/shapetracker.py
CHANGED
@@ -1,30 +1,79 @@
|
|
1
1
|
# ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
|
2
2
|
from __future__ import annotations
|
3
3
|
from dataclasses import dataclass
|
4
|
-
|
4
|
+
import functools
|
5
|
+
from typing import Optional, Callable
|
5
6
|
from tinygrad.helpers import merge_dicts, getenv
|
6
|
-
from tinygrad.shape.view import View, strides_for_shape
|
7
|
+
from tinygrad.shape.view import View, strides_for_shape, unravel
|
7
8
|
from tinygrad.dtype import dtypes
|
8
|
-
from tinygrad.ops import UOp, Ops, graph_rewrite, split_uop, symbolic_flat, Variable, sint, uop_given_valid, simplify_valid
|
9
|
+
from tinygrad.ops import UOp, Ops, graph_rewrite, split_uop, symbolic_flat, Variable, sint, uop_given_valid, simplify_valid, sint_to_uop, Context
|
10
|
+
from tinygrad.codegen.rewriter import sym
|
11
|
+
|
12
|
+
def overflow(u: UOp): return u.vmax > dtypes.max(dtypes.int) or u.vmin < dtypes.min(dtypes.int)
|
13
|
+
|
14
|
+
# If a node overflow, its srcs need to be checked to see if this overflow is the result of an ALU operation,
|
15
|
+
# or that the node simply inherits the dtype from srcs. Upcast is either `Ops.CAST`+`replace` or just `replace`.
|
16
|
+
def upcast(u: UOp):
|
17
|
+
srcs = tuple(upcast(_src) for _src in u.src)
|
18
|
+
if u.dtype.scalar() is dtypes.int:
|
19
|
+
dtype = dtypes.int64.vec(u.dtype.count) if u.dtype.count > 1 else dtypes.int64
|
20
|
+
upcasted = u.replace(dtype=dtype, src=tuple([_src.cast(dtype) for _src in srcs]))
|
21
|
+
if overflow(u): return upcasted
|
22
|
+
# Check the original src, new srcs has Ops.CAST whose vmin, vmax change the real bounds
|
23
|
+
# Cast back is required because if the node is in range, siblings would never be upcasted
|
24
|
+
if any((overflow(src) for src in u.src)): return upcasted.cast(u.dtype)
|
25
|
+
return u.replace(src=tuple(srcs))
|
26
|
+
|
27
|
+
# pooling op may overflow before folding causing unnecessary upcast
|
28
|
+
def folded_upcast(u: UOp):
|
29
|
+
with Context(TRACK_MATCH_STATS=0):
|
30
|
+
return upcast(graph_rewrite(u, sym, {}))
|
31
|
+
|
32
|
+
@functools.lru_cache(None)
|
33
|
+
def views_to_indexed_uops(views: tuple[View, ...], _idxs:Optional[tuple[UOp, ...]]=None) -> tuple[UOp, UOp]:
|
34
|
+
idx, valid = views[-1].to_indexed_uops(_idxs)
|
35
|
+
for view in reversed(views[0:-1]):
|
36
|
+
view = view.minify()
|
37
|
+
idx, valid = view.to_indexed_uops([sint_to_uop(i) for i in unravel(view.shape, idx)], valid)
|
38
|
+
return idx, valid
|
39
|
+
|
40
|
+
@functools.lru_cache(None)
|
41
|
+
def views_to_real_strides(views: tuple[View, ...], ignore_valid=False) -> tuple[Optional[sint], ...]:
|
42
|
+
# NOTE: if a stride is not always valid, it will be None
|
43
|
+
if len(views) == 1 and views[-1].mask is None: return views[-1].strides
|
44
|
+
ret: list[Optional[sint]] = [None] * len(views[-1].shape)
|
45
|
+
idx, valid = (graph_rewrite(u, symbolic_flat) for u in views_to_indexed_uops(views))
|
46
|
+
# TODO: always apply these in to_indexed_uops?
|
47
|
+
if (newvalid:=simplify_valid(valid)) is not None: valid = newvalid
|
48
|
+
if (newidx:=uop_given_valid(valid, idx)) is not None: idx = graph_rewrite(newidx, symbolic_flat)
|
49
|
+
for c in split_uop(idx, Ops.ADD):
|
50
|
+
if c.op is Ops.RANGE: ret[c.arg] = 1
|
51
|
+
if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg] = c.src[1].arg
|
52
|
+
if c.op is Ops.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg] = c.src[0].arg
|
53
|
+
used_ranges = [x.arg for x in idx.toposort if x.op is Ops.RANGE]
|
54
|
+
ret = [x if i in used_ranges else 0 for i,x in enumerate(ret)]
|
55
|
+
if not ignore_valid:
|
56
|
+
for masked_axis in [x.arg for x in valid.toposort if x.op is Ops.RANGE]: ret[masked_axis] = None
|
57
|
+
return tuple(ret)
|
9
58
|
|
10
59
|
@dataclass(frozen=True, order=True)
|
11
60
|
class ShapeTracker:
|
12
|
-
views:
|
61
|
+
views: tuple[View, ...]
|
13
62
|
|
14
63
|
def __add__(self, st:ShapeTracker) -> ShapeTracker:
|
15
64
|
ret = self
|
16
65
|
for v in st.views: ret = ShapeTracker(ret.views + (v,)).simplify() # one view at a time = better simplification
|
17
66
|
return ret
|
18
67
|
|
19
|
-
def invert(self, out_shape:
|
20
|
-
inverted_views:
|
68
|
+
def invert(self, out_shape:tuple[sint, ...]) -> Optional[ShapeTracker]:
|
69
|
+
inverted_views:list[View] = []
|
21
70
|
for v,s in zip(self.views[::-1], [x.shape for x in self.views[::-1][1:]]+[out_shape]):
|
22
71
|
if (inverted:= v.invert(s)) is None: return None
|
23
72
|
inverted_views.append(inverted)
|
24
73
|
return ShapeTracker(tuple(inverted_views)).reshape(out_shape)
|
25
74
|
|
26
75
|
@staticmethod
|
27
|
-
def from_shape(shape:
|
76
|
+
def from_shape(shape:tuple[sint, ...]) -> ShapeTracker: return ShapeTracker((View.create(shape),))
|
28
77
|
|
29
78
|
@property
|
30
79
|
def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous
|
@@ -33,65 +82,42 @@ class ShapeTracker:
|
|
33
82
|
def consecutive(self) -> bool: return len(self.views) == 1 and (v:=self.views[0]).mask is None and v.strides == strides_for_shape(v.shape)
|
34
83
|
|
35
84
|
@property
|
36
|
-
def shape(self) ->
|
85
|
+
def shape(self) -> tuple[sint, ...]: return self.views[-1].shape
|
37
86
|
|
38
87
|
@property
|
39
88
|
def size(self) -> int: return self.views[-1].size()
|
40
89
|
|
41
|
-
def reduce(self, axis:
|
90
|
+
def reduce(self, axis:tuple[int, ...]) -> tuple[sint, ...]: return tuple(1 if i in axis else s for i,s in enumerate(self.shape))
|
42
91
|
|
43
92
|
def to_uop(self) -> UOp: return UOp(Ops.VIEW, dtypes.void, (), self)
|
93
|
+
def to_indexed_uops(self, _idxs:Optional[list[UOp]|tuple[UOp, ...]]=None) -> tuple[UOp, UOp]:
|
94
|
+
idx, valid = views_to_indexed_uops(self.views, tuple(_idxs) if _idxs is not None else None)
|
95
|
+
return folded_upcast(idx), folded_upcast(valid)
|
44
96
|
|
45
|
-
|
46
|
-
idx, valid = self.views[-1].to_indexed_uops(_idxs)
|
47
|
-
for view in reversed(self.views[0:-1]):
|
48
|
-
view = view.minify()
|
49
|
-
acc, idxs = 1, []
|
50
|
-
for d in reversed(view.shape):
|
51
|
-
idxs.append((idx//acc)%d)
|
52
|
-
acc *= d
|
53
|
-
idx, valid = view.to_indexed_uops(idxs[::-1], valid)
|
54
|
-
return idx, valid
|
55
|
-
|
97
|
+
# upper bound on buffer size required to fit this shapetracker
|
56
98
|
def real_size(self) -> int:
|
57
99
|
if 0 in self.shape: return 0
|
58
|
-
|
59
|
-
|
100
|
+
view = (v.shrink(v.mask) if (v:=self.views[0]).mask else v)
|
101
|
+
idx, _ = views_to_indexed_uops((view,))
|
60
102
|
assert idx.vmax < 1e12, f"real_size broken for {self}"
|
61
|
-
return int(idx.vmax+1)
|
103
|
+
return int(idx.vmax + 1)
|
62
104
|
|
63
|
-
def vars(self) ->
|
105
|
+
def vars(self) -> set[Variable]: return set().union(*[v.vars() for v in self.views])
|
64
106
|
|
65
107
|
@property
|
66
|
-
def var_vals(self) ->
|
108
|
+
def var_vals(self) -> dict[Variable, int]: return merge_dicts([dict([v.unbind()]) for v in self.vars()])
|
67
109
|
|
68
|
-
def unbind(self) ->
|
110
|
+
def unbind(self) -> tuple[ShapeTracker, dict[Variable, int]]:
|
69
111
|
unbound_views, var_vals = zip(*[v.unbind() for v in self.views])
|
70
112
|
return ShapeTracker(tuple(unbound_views)), merge_dicts(var_vals)
|
71
113
|
|
72
|
-
|
73
|
-
def
|
74
|
-
if len(self.views) == 1 and self.views[-1].mask is None: return self.views[-1].strides
|
75
|
-
ret: List[Optional[sint]] = [None] * len(self.shape)
|
76
|
-
idx, valid = (graph_rewrite(u, symbolic_flat) for u in self.to_indexed_uops())
|
77
|
-
# TODO: always apply these in to_indexed_uops?
|
78
|
-
if (newvalid:=simplify_valid(valid)) is not None: valid = newvalid
|
79
|
-
if (newidx:=uop_given_valid(valid, idx)) is not None: idx = graph_rewrite(newidx, symbolic_flat)
|
80
|
-
for c in split_uop(idx, Ops.ADD):
|
81
|
-
if c.op is Ops.RANGE: ret[c.arg[0]] = 1
|
82
|
-
if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg[0]] = c.src[1].arg
|
83
|
-
if c.op is Ops.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg[0]] = c.src[0].arg
|
84
|
-
used_ranges = [x.arg[0] for x in idx.sparents if x.op is Ops.RANGE]
|
85
|
-
ret = [x if i in used_ranges else 0 for i,x in enumerate(ret)]
|
86
|
-
if not ignore_valid:
|
87
|
-
for masked_axis in [x.arg[0] for x in valid.sparents if x.op is Ops.RANGE]: ret[masked_axis] = None
|
88
|
-
return tuple(ret)
|
89
|
-
|
90
|
-
def unit_stride_axes(self, ignore_valid=False) -> List[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]
|
114
|
+
def real_strides(self, ignore_valid=False) -> tuple[Optional[sint], ...]: return views_to_real_strides(self.views, ignore_valid)
|
115
|
+
def unit_stride_axes(self, ignore_valid=False) -> list[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]
|
91
116
|
|
92
117
|
def axis_is_masked(self, axis:int) -> bool:
|
93
|
-
|
94
|
-
|
118
|
+
with Context(TRACK_MATCH_STATS=0):
|
119
|
+
_, valid = self.to_indexed_uops()
|
120
|
+
return axis in [x.arg for x in graph_rewrite(valid, symbolic_flat).toposort if x.op is Ops.RANGE]
|
95
121
|
|
96
122
|
def simplify(self) -> ShapeTracker:
|
97
123
|
if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None:
|
@@ -100,12 +126,17 @@ class ShapeTracker:
|
|
100
126
|
|
101
127
|
# *** under this line are the movement ops ***
|
102
128
|
|
103
|
-
def pad(self, arg:
|
104
|
-
def shrink(self, arg:
|
105
|
-
def expand(self, new_shape:
|
106
|
-
def permute(self, axis:
|
107
|
-
def
|
129
|
+
def pad(self, arg: tuple[tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].pad(arg), ))
|
130
|
+
def shrink(self, arg: tuple[tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].shrink(arg), ))
|
131
|
+
def expand(self, new_shape: tuple[sint, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].expand(new_shape), ))
|
132
|
+
def permute(self, axis: tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].permute(axis), ))
|
133
|
+
def flip(self, mul: tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].flip(mul), ))
|
108
134
|
|
109
|
-
def reshape(self, new_shape:
|
135
|
+
def reshape(self, new_shape: tuple[sint, ...]) -> ShapeTracker:
|
110
136
|
if getenv("MERGE_VIEW", 1) and (new_view := self.views[-1].reshape(new_shape)) is not None: return ShapeTracker(self.views[0:-1] + (new_view,))
|
111
137
|
return ShapeTracker(self.views + (View.create(new_shape), ))
|
138
|
+
|
139
|
+
def mop(self, op, arg): return mops[op](self, arg)
|
140
|
+
|
141
|
+
mops: dict[Ops, Callable] = {Ops.RESHAPE: ShapeTracker.reshape, Ops.PERMUTE: ShapeTracker.permute, Ops.EXPAND: ShapeTracker.expand,
|
142
|
+
Ops.SHRINK: ShapeTracker.shrink, Ops.FLIP: ShapeTracker.flip, Ops.PAD: ShapeTracker.pad}
|