tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, exactly as they appear in their public registry; it is provided for informational purposes only.
- tinygrad/__init__.py +11 -6
- tinygrad/codegen/kernel.py +308 -175
- tinygrad/codegen/linearize.py +95 -0
- tinygrad/codegen/lowerer.py +143 -0
- tinygrad/codegen/transcendental.py +257 -0
- tinygrad/codegen/uopgraph.py +506 -0
- tinygrad/device.py +72 -171
- tinygrad/dtype.py +122 -47
- tinygrad/engine/jit.py +184 -87
- tinygrad/{lazy.py → engine/lazy.py} +74 -66
- tinygrad/engine/memory.py +51 -0
- tinygrad/engine/realize.py +86 -61
- tinygrad/engine/schedule.py +366 -317
- tinygrad/engine/search.py +58 -47
- tinygrad/function.py +59 -58
- tinygrad/helpers.py +120 -102
- tinygrad/multi.py +82 -78
- tinygrad/nn/__init__.py +116 -67
- tinygrad/nn/datasets.py +12 -5
- tinygrad/nn/optim.py +1 -1
- tinygrad/nn/state.py +91 -6
- tinygrad/ops.py +1126 -143
- tinygrad/renderer/__init__.py +47 -23
- tinygrad/renderer/cstyle.py +338 -265
- tinygrad/renderer/llvmir.py +125 -143
- tinygrad/renderer/ptx.py +225 -0
- tinygrad/runtime/autogen/adreno.py +17904 -0
- tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
- tinygrad/runtime/autogen/cuda.py +6 -162
- tinygrad/runtime/autogen/io_uring.py +97 -63
- tinygrad/runtime/autogen/kfd.py +60 -47
- tinygrad/runtime/autogen/kgsl.py +1386 -0
- tinygrad/runtime/autogen/libc.py +5462 -0
- tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
- tinygrad/runtime/autogen/nvrtc.py +579 -0
- tinygrad/runtime/autogen/opencl.py +11 -11
- tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
- tinygrad/runtime/graph/clang.py +3 -3
- tinygrad/runtime/graph/cuda.py +11 -15
- tinygrad/runtime/graph/hcq.py +120 -107
- tinygrad/runtime/graph/metal.py +71 -43
- tinygrad/runtime/ops_amd.py +244 -323
- tinygrad/runtime/ops_clang.py +12 -5
- tinygrad/runtime/ops_cloud.py +220 -0
- tinygrad/runtime/ops_cuda.py +42 -99
- tinygrad/runtime/ops_disk.py +25 -26
- tinygrad/runtime/ops_dsp.py +181 -0
- tinygrad/runtime/ops_gpu.py +29 -16
- tinygrad/runtime/ops_hip.py +68 -0
- tinygrad/runtime/ops_llvm.py +15 -10
- tinygrad/runtime/ops_metal.py +147 -64
- tinygrad/runtime/ops_nv.py +356 -397
- tinygrad/runtime/ops_python.py +78 -79
- tinygrad/runtime/ops_qcom.py +405 -0
- tinygrad/runtime/support/__init__.py +0 -0
- tinygrad/runtime/support/compiler_cuda.py +77 -0
- tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
- tinygrad/runtime/support/elf.py +38 -0
- tinygrad/runtime/support/hcq.py +539 -0
- tinygrad/shape/shapetracker.py +40 -50
- tinygrad/shape/view.py +102 -63
- tinygrad/tensor.py +1109 -365
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
- tinygrad-0.10.0.dist-info/RECORD +77 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/linearizer.py +0 -528
- tinygrad/codegen/uops.py +0 -451
- tinygrad/engine/graph.py +0 -100
- tinygrad/renderer/assembly.py +0 -269
- tinygrad/shape/symbolic.py +0 -327
- tinygrad-0.9.1.dist-info/RECORD +0 -63
- /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
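Beyond the large autogen additions, the listing shows a module reorganization: `tinygrad/lazy.py` moved to `tinygrad/engine/lazy.py`, and `shape/symbolic.py`, `codegen/linearizer.py`, and `codegen/uops.py` were removed, with their roles absorbed by `tinygrad/ops.py` and `tinygrad/codegen/kernel.py`. A rough sketch of how downstream import paths shift, inferred from the renames above and the import hunks in the diffs below (the old `Linearizer` import path is an assumption based on the deleted `codegen/linearizer.py`):

```python
# 0.9.1-era imports (modules deleted or moved in 0.10.0)
# from tinygrad.lazy import LazyBuffer
# from tinygrad.shape.symbolic import sint, sym_infer
# from tinygrad.codegen.linearizer import Linearizer   # assumed old path; file removed above

# 0.10.0 equivalents, as used in the diffs below
from tinygrad.engine.lazy import LazyBuffer
from tinygrad.ops import sint, sym_infer
from tinygrad.codegen.kernel import Kernel   # Kernel takes over the Linearizer role in engine/search.py
```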
tinygrad/engine/search.py
CHANGED
@@ -2,26 +2,25 @@ from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
 import itertools, functools, random, math, time, multiprocessing, traceback, signal
 from collections import defaultdict
 from dataclasses import replace
+from tinygrad.ops import UOp, Ops, Variable, sym_infer
 from tinygrad.device import Device, Buffer, Compiler
-from tinygrad.ops import MemBuffer
 from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
-from tinygrad.dtype import ImageDType
-from tinygrad.codegen.
-from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError
-from tinygrad.codegen.uops import UOpGraph
+from tinygrad.dtype import ImageDType, PtrDType
+from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
 from tinygrad.tensor import Tensor
-from tinygrad.shape.symbolic import sym_infer
 from tinygrad.engine.realize import CompiledRunner
 from tinygrad.renderer import Program

 actions = [Opt(op=OptOps.UPCAST, axis=axis, amt=amt) for amt in [0,2,3,4,5,7] for axis in range(6)]
 actions += [Opt(op=OptOps.UNROLL, axis=axis, amt=amt) for amt in [0,4,7] for axis in range(5)]
-actions += [Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,13,16,29] for axis in range(
+actions += [Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,13,16,29] for axis in range(6)]
 actions += [Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,28,29,32,49,64,256] for axis in range(3)]
 actions += [Opt(op=OptOps.GROUP, axis=axis, amt=amt) for amt in [0,4,8,16] for axis in range(3)]
 if getenv("BEAM_PADTO", 1): actions += [Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32] for axis in range(7)]
-actions += [Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.
+actions += [Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.LOCAL, axis=6, amt=2)]
+actions += [Opt(op=OptOps.UPCASTMID, axis=1, amt=4), Opt(op=OptOps.TC, axis=0, amt=0)]
 actions += [Opt(op=OptOps.TC, axis=axis, amt=getenv("TC_OPT", 2)) for axis in range(9)] # covers resnet kernels (3 global * 3 reduce)
+actions += [Opt(op=OptOps.SWAP, axis=axis, amt=amt) for axis in range(5) for amt in range(axis+1, 5)]
 if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]

 def _get_test_global_size(global_size, max_global_size, var_vals):
@@ -34,7 +33,8 @@ def _get_test_global_size(global_size, max_global_size, var_vals):
         break
   return test_global_size, factor

-def _time_program(p:Program, lib:bytes, var_vals, rawbufs, early_stop=None,
+def _time_program(p:Program, lib:bytes, var_vals:Dict[Variable, int], rawbufs:List[Buffer], early_stop:Optional[float]=None,
+                  max_global_size:Optional[int]=65536, clear_l2=False, cnt=3, name="test") -> List[float]:
   factor = 1
   if p.global_size is not None and max_global_size is not None:
     global_size, factor = _get_test_global_size(p.global_size, max_global_size, var_vals)
@@ -42,39 +42,39 @@ def _time_program(p:Program, lib:bytes, var_vals, rawbufs, early_stop=None, max_
   try: car = CompiledRunner(p, precompiled=lib)
   except AssertionError: return [math.inf] * cnt
   tms = []
-  input_bufs = [rawbufs[i] for i
+  input_bufs = [rawbufs[i] for i in car.p.globals]
   for _ in range(cnt):
     if clear_l2:
-
+      if hasattr(dev:=Device[p.dname], 'invalidate_caches'): dev.invalidate_caches()
+      else:
+        with Context(DEBUG=0, BEAM=0, CAPTURING=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
     tms.append(cast(float, car(input_bufs, var_vals, wait=True))*factor)
-    if early_stop is not None and early_stop < tms
+    if early_stop is not None and early_stop < min(tms): break
   return tms

 class TimeoutException(Exception): pass
 def timeout_handler(signum, frame): raise TimeoutException()

-def _try_compile_linearized_w_idx(x:Tuple[int,
-
-
-
+def _try_compile_linearized_w_idx(x:Tuple[int,Kernel], compiler:Compiler) -> Tuple[int, Optional[Tuple[Program, bytes, float]]]:
+  if hasattr(signal, "alarm"):
+    signal.signal(getattr(signal, 'SIGALRM'), timeout_handler)
+    # set timeout
+    signal.alarm(getenv("BEAM_TIMEOUT_SEC", 10))
+  ret = None
   try:
-    x[1].
-
-    p
+    p = x[1].to_program(name_override="test")
+    assert p.uops is not None, "uop list wasn't generated?"
+    if len(p.uops) >= getenv("BEAM_UOPS_MAX", 3000) > 0: raise RuntimeError("too many uops")
     st = time.perf_counter()
     prog = compiler.compile(p.src)
     et = time.perf_counter() - st
     ret = (p, prog, et)
   except RuntimeError:
     if DEBUG >= 4: traceback.print_exc()
-    ret = None
-  except TimeoutException:
-    ret = None
   except Exception as e:
     if getenv("BEAM_STRICT_MODE"): raise e
-    ret = None
   finally:
-    signal.alarm(0)
+    if hasattr(signal, "alarm"): signal.alarm(0)
   return x[0], ret

 # workers should ignore ctrl c
@@ -85,19 +85,22 @@ def _ensure_buffer_alloc(bufs:List[Buffer]) -> List[Buffer]: return [buf.ensure_
 # *** external API ***

 # get (scrap) buffers for timing the linearizer
-def bufs_from_lin(lin:
-  bufsts:DefaultDict[int, List[
-  for x in lin.
-
+def bufs_from_lin(lin:Kernel, allocate:bool=True) -> List[Buffer]:
+  bufsts: DefaultDict[int, List[UOp]] = defaultdict(list)
+  for x in lin.bufs:
+    if x.src[0].op is Ops.DEFINE_GLOBAL: bufsts[x.src[0].arg].append(x)
+  rawbufs: List[Optional[Buffer]] = [None]*len(bufsts)
   for k,lx in bufsts.items():
-    buf_size = prod(
+    buf_size = prod(dtype.shape) if isinstance(dtype:=lx[0].src[0].dtype, ImageDType) else max(y.st_arg.real_size() for y in lx)
+    assert isinstance(dtype, (PtrDType, ImageDType))
     if buf_size == 0: buf_size = 1 # create a size 1 buffer if no cell is accessed in kernel. # TODO: remove from kernel input in this case.
-
+    buf_dtype = dtype if isinstance(dtype, ImageDType) else dtype.base
+    rawbufs[k] = Buffer(lin.opts.device, buf_size, buf_dtype).allocate() if allocate else Buffer(lin.opts.device, buf_size, buf_dtype)
   assert all(r is not None for r in rawbufs)
   return cast(List[Buffer], rawbufs)

 # get dictionary of all possible actions
-def
+def get_kernel_actions(lin:Kernel, include_0=True) -> Dict[int, Kernel]:
   acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024)
   for i,a in enumerate(actions):
     if a.axis is not None and a.op is not OptOps.TC:
@@ -114,19 +117,19 @@ def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Lineariz
     except KernelOptError: pass
   return acted_lins

-beam_pool, BEAM_DEBUG = None, getenv("BEAM_DEBUG")
-def beam_search(lin:
+beam_pool, BEAM_DEBUG, CAPTURE_BEAM = None, getenv("BEAM_DEBUG"), getenv("CAPTURE_BEAM", "")
+def beam_search(lin:Kernel, rawbufs:List[Buffer], amt:int, allow_test_size=True, disable_cache=getenv("IGNORE_BEAM_CACHE")) -> Kernel:
   global beam_pool
-  key = {"ast": lin.ast
-  if not
+  key = {"ast": lin.ast.key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix}
+  if not disable_cache and CACHELEVEL >= 1 and (val:=diskcache_get("beam_search", key)) is not None:
     ret = lin.copy()
     for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
     return ret

-  beam: List[Tuple[
+  beam: List[Tuple[Kernel, float]] = [(lin, float("inf"))]
   seen_libs = set()

-  default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "AMD", "NV"} else 0
+  default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "AMD", "NV", "METAL"} else 0
   if beam_pool is None and (workers := getenv("PARALLEL", default_parallel)):
     beam_pool = multiprocessing.get_context("spawn").Pool(workers, _init_worker, (), getenv("BEAM_MAX_TASKS_PER_CHILD", 16))

@@ -136,23 +139,31 @@ def beam_search(lin:Linearizer, rawbufs:List[Buffer], amt:int, allow_test_size=T

   try:
     rawbufs = _ensure_buffer_alloc(rawbufs)
-    var_vals = {k:(k.
+    var_vals: Dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
     exiting, st = False, time.perf_counter()
     dev = Device[lin.opts.device]
     while not exiting:
-      acted_lins: List[
-      timed_lins: List[Tuple[
+      acted_lins: List[Kernel] = flatten([get_kernel_actions(lin, include_0=False).values() for lin,_ in beam])
+      timed_lins: List[Tuple[Kernel, float]] = []
       _compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=dev.compiler)
+      least_compute_ops = math.inf
       for i,proc in (map(_compile_fn, enumerate(acted_lins)) if beam_pool is None else beam_pool.imap_unordered(_compile_fn, enumerate(acted_lins))):
         if proc is None: continue
         p, lib, compile_et = proc
         if lib in seen_libs: continue
-        #
+        # filter out kernels that use 1000x more compute than the smallest
+        least_compute_ops = min(this_compute_ops:=sym_infer(p.op_estimate, var_vals), least_compute_ops)
+        if least_compute_ops*1000 < this_compute_ops: continue
+        if len(CAPTURE_BEAM) > 0:
+          with open(CAPTURE_BEAM, 'a') as f: f.write(str(acted_lins[i].ast).replace('\n','')+f" :: {acted_lins[i].applied_opts}\n")
         seen_libs.add(lib)
-        try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0)
-        except RuntimeError
+        try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0, clear_l2=hasattr(dev, 'invalidate_caches'))
+        except RuntimeError as e:
+          if len(CAPTURE_BEAM) > 0:
+            with open(CAPTURE_BEAM, 'a') as f: f.write("# Upper ast finished with an error:" + str(e).replace('\n',' ')+ "\n")
+          continue # for runtime issues
         timed_lins.append((acted_lins[i], min(tms)))
-        if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(
+        if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(List, p.uops)):5d} uops {compile_et*1e6:12.2f} us compile/{timed_lins[-1][1]*1e6:12.2f} us run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501
         elif DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501

       # done
@@ -181,8 +192,8 @@ def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[Buff
   assert not math.isinf(ret[0]), "all optimize_local_size exec failed"
   return ret[1]

-def time_linearizer(lin:
-  key = {"ast": lin.ast
+def time_linearizer(lin:Kernel, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501
+  key = {"ast": lin.ast.key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size,
         "max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device, "suffix": lin.opts.suffix}
   if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)

@@ -190,7 +201,7 @@ def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True,
   assert dev.compiler is not None

   rawbufs = _ensure_buffer_alloc(rawbufs)
-  var_vals = {k:(k.
+  var_vals: Dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
   p = lin.to_program()
   tms = _time_program(p, dev.compiler.compile(p.src), var_vals, rawbufs,
                       max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))
tinygrad/function.py
CHANGED
@@ -3,10 +3,9 @@ import math
 from typing import Tuple, Optional
 from tinygrad.helpers import argsort
 from tinygrad.dtype import dtypes, DType, sum_acc_dtype
-from tinygrad.ops import
+from tinygrad.ops import Ops, resolve, sint
 from tinygrad.tensor import Function
-from tinygrad.lazy import LazyBuffer
-from tinygrad.shape.symbolic import sint
+from tinygrad.engine.lazy import LazyBuffer

 class Contiguous(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer: return x.contiguous()
@@ -19,95 +18,96 @@ class ContiguousBackward(Function):
 class Cast(Function):
   def forward(self, x:LazyBuffer, dtype:DType, bitcast:bool=False) -> LazyBuffer:
     self.input_dtype, self.bitcast = x.dtype, bitcast
-    return x.
+    return x.bitcast(dtype) if self.bitcast else x.cast(dtype)

-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
+    if self.bitcast: raise RuntimeError("bitcast cannot backward")
+    return grad_output.cast(self.input_dtype)

 # ************* unary ops *************

-class Neg(Function):
-  def forward(self, x:LazyBuffer) -> LazyBuffer: return x.e(UnaryOps.NEG)
-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(UnaryOps.NEG)
-
 class Reciprocal(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret = x.
+    self.ret = x.reciprocal()
     return self.ret
-
-
+
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return -grad_output * self.ret * self.ret

 class Sin(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.x = x
-    return x.
+    return x.sin()

-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return self.x.const(math.pi / 2).e(BinaryOps.ADD, self.x.e(UnaryOps.NEG)).e(UnaryOps.SIN).e(BinaryOps.MUL, grad_output)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return (math.pi/2 - self.x).sin() * grad_output

-# NOTE: maximum(x, 0) behaves differently where x=0
 class Relu(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret = x.
+    self.ret = x.maximum(0)
     return self.ret

-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return self.ret.const(0).e(BinaryOps.CMPLT, self.ret).cast(grad_output.dtype).e(BinaryOps.MUL, grad_output)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return self.ret.gt(0).cast(grad_output.dtype) * grad_output

 class Log(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.x = x
-    return x.
+    return x.log2() * math.log(2)

-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output / self.x

 class Exp(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret = x
+    self.ret = (x * (1/math.log(2))).exp2()
     return self.ret

-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return self.ret
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return self.ret * grad_output

 class Sqrt(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret = x.
+    self.ret = x.sqrt()
     return self.ret

-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.e(BinaryOps.MUL, self.ret.e(BinaryOps.MUL, self.ret.const(2)).e(UnaryOps.RECIP))
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output / (self.ret*2)

 # NOTE: the implicit derivative of sigmoid is not stable
 # https://towardsdatascience.com/derivative-of-the-sigmoid-function-536880cf918e
 # TODO: have the backend automatically find this
 class Sigmoid(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret =
+    self.ret = (1 + (x * (-1/math.log(2))).exp2()).reciprocal()
     return self.ret

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return self.ret
+    return (self.ret * (1 - self.ret)) * grad_output

 class Sign(Function):
-  def forward(self, x:LazyBuffer) -> LazyBuffer:
-    return x.e(BinaryOps.CMPNE, x.const(0)).e(
-      TernaryOps.WHERE, x.e(BinaryOps.CMPLT, x.const(0)).e(TernaryOps.WHERE, x.const(-1), x.const(1)), x.const(0))
+  def forward(self, x:LazyBuffer) -> LazyBuffer: return x.ne(0).where(x.lt(0).where(x.const_like(-1), x.const_like(1)), x.const_like(0))
   # backward always return 0 to match torch
-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.const_like(0)

 # ************* binary ops *************

 class Less(Function):
-  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.lt(y)
   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None

 class Neq(Function):
-  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.ne(y)
   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None

 class Xor(Function):
-  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x^y
+
+class BitwiseAnd(Function):
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x&y
+
+class BitwiseOr(Function):
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x|y
+
+class Threefry(Function):
+  def forward(self, x:LazyBuffer, seed:LazyBuffer) -> LazyBuffer: return x.threefry(seed)

 class Add(Function):
-  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x+y

   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
     return grad_output if self.needs_input_grad[0] else None, \
@@ -116,64 +116,65 @@ class Add(Function):
 class Mul(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
     self.x, self.y = x, y
-    return x
+    return x * y

   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
-    return self.y
-           self.x
+    return (self.y * grad_output) if self.needs_input_grad[0] else None, \
+           (self.x * grad_output) if self.needs_input_grad[1] else None

-class
-  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    self.x, self.y = x, y
-    return x.e(BinaryOps.MUL, y.e(UnaryOps.RECIP)) if not dtypes.is_int(x.dtype) else x.e(BinaryOps.IDIV, y)
-
-  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
-    return grad_output.e(BinaryOps.MUL, self.y.e(UnaryOps.RECIP)) if self.needs_input_grad[0] else None, \
-           grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.x).e(BinaryOps.MUL, self.y.e(BinaryOps.MUL, self.y).e(UnaryOps.RECIP)) if self.needs_input_grad[1] else None # noqa: E501
+class IDiv(Function):
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x // y

 # ************* ternary ops *************

 class Where(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer, z:LazyBuffer) -> LazyBuffer:
     self.x = x
-    return self.x.
+    return self.x.where(y, z)

   def backward(self, grad_output:LazyBuffer) -> Tuple[None, Optional[LazyBuffer], Optional[LazyBuffer]]:
     return None, \
-      self.x.
-      self.x.
+      self.x.where(grad_output, grad_output.const_like(0)) if self.needs_input_grad[1] else None, \
+      self.x.where(grad_output.const_like(0), grad_output) if self.needs_input_grad[2] else None

 # ************* reduce ops *************

 class Sum(Function):
   def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
     self.input_shape = x.shape
-    return x.r(
+    return x.r(Ops.ADD, axis)

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.expand(self.input_shape)

+class Prod(Function):
+  def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
+    self.x, self.ret = x, x.r(Ops.MUL, axis)
+    return self.ret
+
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
+    return (grad_output * self.ret).expand(self.x.shape) / self.x
+
 class Max(Function):
   def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
-    self.x, self.ret, self.axis = x, x.r(
+    self.x, self.ret, self.axis = x, x.r(Ops.MAX, axis), axis
     return self.ret

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
     # 1s in locations where the max was chosen (can be two locations)
-    max_is_1s = self.x.
-
-    div
-    return max_is_1s.e(BinaryOps.MUL, div.e(UnaryOps.RECIP)).cast(grad_output.dtype).e(BinaryOps.MUL, grad_output.expand(self.x.shape))
+    max_is_1s = self.x.ne(self.ret.expand(self.x.shape)).ne(self.x.const_like(1).cast(dtypes.bool)).cast(grad_output.dtype)
+    div = max_is_1s.r(Ops.ADD, self.axis).expand(self.x.shape)
+    return (max_is_1s/div) * grad_output.expand(self.x.shape)

 # ************* movement ops *************

 # NOTE: this is sum in reverse
 class Expand(Function):
   def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
-    self.expanded_axis = tuple(i for i, (si, so) in enumerate(zip(x.shape, shape)) if si != so)
+    self.expanded_axis = tuple(i for i, (si, so) in enumerate(zip(x.shape, shape)) if resolve(si != so))
     return x.expand(shape)

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.cast(sum_acc_dtype(grad_output.dtype)).r(
+    return grad_output.cast(sum_acc_dtype(grad_output.dtype)).r(Ops.ADD, self.expanded_axis).cast(grad_output.dtype)

 class Reshape(Function):
   def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer: