tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
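The same comparison can be reproduced locally from the two published wheels using only the Python standard library. A minimal sketch; the wheel filenames are assumptions, use whatever "pip download tinygrad==0.9.1 --no-deps" and "pip download tinygrad==0.10.0 --no-deps" actually fetch:

import zipfile, difflib

def read_member(wheel_path: str, member: str) -> list:
  # a wheel is a plain zip archive, so any file inside it can be read directly
  with zipfile.ZipFile(wheel_path) as z:
    return z.read(member).decode().splitlines(keepends=True)

old = read_member("tinygrad-0.9.1-py3-none-any.whl", "tinygrad/engine/search.py")
new = read_member("tinygrad-0.10.0-py3-none-any.whl", "tinygrad/engine/search.py")
# produces hunks in the same unified-diff shape as the listings below
print("".join(difflib.unified_diff(old, new, "0.9.1/tinygrad/engine/search.py", "0.10.0/tinygrad/engine/search.py")))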
Files changed (74)
  1. tinygrad/__init__.py +11 -6
  2. tinygrad/codegen/kernel.py +308 -175
  3. tinygrad/codegen/linearize.py +95 -0
  4. tinygrad/codegen/lowerer.py +143 -0
  5. tinygrad/codegen/transcendental.py +257 -0
  6. tinygrad/codegen/uopgraph.py +506 -0
  7. tinygrad/device.py +72 -171
  8. tinygrad/dtype.py +122 -47
  9. tinygrad/engine/jit.py +184 -87
  10. tinygrad/{lazy.py → engine/lazy.py} +74 -66
  11. tinygrad/engine/memory.py +51 -0
  12. tinygrad/engine/realize.py +86 -61
  13. tinygrad/engine/schedule.py +366 -317
  14. tinygrad/engine/search.py +58 -47
  15. tinygrad/function.py +59 -58
  16. tinygrad/helpers.py +120 -102
  17. tinygrad/multi.py +82 -78
  18. tinygrad/nn/__init__.py +116 -67
  19. tinygrad/nn/datasets.py +12 -5
  20. tinygrad/nn/optim.py +1 -1
  21. tinygrad/nn/state.py +91 -6
  22. tinygrad/ops.py +1126 -143
  23. tinygrad/renderer/__init__.py +47 -23
  24. tinygrad/renderer/cstyle.py +338 -265
  25. tinygrad/renderer/llvmir.py +125 -143
  26. tinygrad/renderer/ptx.py +225 -0
  27. tinygrad/runtime/autogen/adreno.py +17904 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
  29. tinygrad/runtime/autogen/cuda.py +6 -162
  30. tinygrad/runtime/autogen/io_uring.py +97 -63
  31. tinygrad/runtime/autogen/kfd.py +60 -47
  32. tinygrad/runtime/autogen/kgsl.py +1386 -0
  33. tinygrad/runtime/autogen/libc.py +5462 -0
  34. tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
  35. tinygrad/runtime/autogen/nvrtc.py +579 -0
  36. tinygrad/runtime/autogen/opencl.py +11 -11
  37. tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
  38. tinygrad/runtime/graph/clang.py +3 -3
  39. tinygrad/runtime/graph/cuda.py +11 -15
  40. tinygrad/runtime/graph/hcq.py +120 -107
  41. tinygrad/runtime/graph/metal.py +71 -43
  42. tinygrad/runtime/ops_amd.py +244 -323
  43. tinygrad/runtime/ops_clang.py +12 -5
  44. tinygrad/runtime/ops_cloud.py +220 -0
  45. tinygrad/runtime/ops_cuda.py +42 -99
  46. tinygrad/runtime/ops_disk.py +25 -26
  47. tinygrad/runtime/ops_dsp.py +181 -0
  48. tinygrad/runtime/ops_gpu.py +29 -16
  49. tinygrad/runtime/ops_hip.py +68 -0
  50. tinygrad/runtime/ops_llvm.py +15 -10
  51. tinygrad/runtime/ops_metal.py +147 -64
  52. tinygrad/runtime/ops_nv.py +356 -397
  53. tinygrad/runtime/ops_python.py +78 -79
  54. tinygrad/runtime/ops_qcom.py +405 -0
  55. tinygrad/runtime/support/__init__.py +0 -0
  56. tinygrad/runtime/support/compiler_cuda.py +77 -0
  57. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
  58. tinygrad/runtime/support/elf.py +38 -0
  59. tinygrad/runtime/support/hcq.py +539 -0
  60. tinygrad/shape/shapetracker.py +40 -50
  61. tinygrad/shape/view.py +102 -63
  62. tinygrad/tensor.py +1109 -365
  63. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
  64. tinygrad-0.10.0.dist-info/RECORD +77 -0
  65. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
  66. tinygrad/codegen/linearizer.py +0 -528
  67. tinygrad/codegen/uops.py +0 -451
  68. tinygrad/engine/graph.py +0 -100
  69. tinygrad/renderer/assembly.py +0 -269
  70. tinygrad/shape/symbolic.py +0 -327
  71. tinygrad-0.9.1.dist-info/RECORD +0 -63
  72. /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
  73. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
  74. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
tinygrad/engine/search.py CHANGED
@@ -2,26 +2,25 @@ from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
  import itertools, functools, random, math, time, multiprocessing, traceback, signal
  from collections import defaultdict
  from dataclasses import replace
+ from tinygrad.ops import UOp, Ops, Variable, sym_infer
  from tinygrad.device import Device, Buffer, Compiler
- from tinygrad.ops import MemBuffer
  from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
- from tinygrad.dtype import ImageDType
- from tinygrad.codegen.linearizer import Linearizer
- from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError
- from tinygrad.codegen.uops import UOpGraph
+ from tinygrad.dtype import ImageDType, PtrDType
+ from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
  from tinygrad.tensor import Tensor
- from tinygrad.shape.symbolic import sym_infer
  from tinygrad.engine.realize import CompiledRunner
  from tinygrad.renderer import Program

  actions = [Opt(op=OptOps.UPCAST, axis=axis, amt=amt) for amt in [0,2,3,4,5,7] for axis in range(6)]
  actions += [Opt(op=OptOps.UNROLL, axis=axis, amt=amt) for amt in [0,4,7] for axis in range(5)]
- actions += [Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,13,16,29] for axis in range(5)]
+ actions += [Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,13,16,29] for axis in range(6)]
  actions += [Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,28,29,32,49,64,256] for axis in range(3)]
  actions += [Opt(op=OptOps.GROUP, axis=axis, amt=amt) for amt in [0,4,8,16] for axis in range(3)]
  if getenv("BEAM_PADTO", 1): actions += [Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32] for axis in range(7)]
- actions += [Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.UPCASTMID, axis=1, amt=4), Opt(op=OptOps.TC, axis=0, amt=0)]
+ actions += [Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.LOCAL, axis=6, amt=2)]
+ actions += [Opt(op=OptOps.UPCASTMID, axis=1, amt=4), Opt(op=OptOps.TC, axis=0, amt=0)]
  actions += [Opt(op=OptOps.TC, axis=axis, amt=getenv("TC_OPT", 2)) for axis in range(9)] # covers resnet kernels (3 global * 3 reduce)
+ actions += [Opt(op=OptOps.SWAP, axis=axis, amt=amt) for axis in range(5) for amt in range(axis+1, 5)]
  if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]

  def _get_test_global_size(global_size, max_global_size, var_vals):
@@ -34,7 +33,8 @@ def _get_test_global_size(global_size, max_global_size, var_vals):
          break
    return test_global_size, factor

- def _time_program(p:Program, lib:bytes, var_vals, rawbufs, early_stop=None, max_global_size=65536, clear_l2=False, cnt=3, name="test"):
+ def _time_program(p:Program, lib:bytes, var_vals:Dict[Variable, int], rawbufs:List[Buffer], early_stop:Optional[float]=None,
+                   max_global_size:Optional[int]=65536, clear_l2=False, cnt=3, name="test") -> List[float]:
    factor = 1
    if p.global_size is not None and max_global_size is not None:
      global_size, factor = _get_test_global_size(p.global_size, max_global_size, var_vals)
@@ -42,39 +42,39 @@ def _time_program(p:Program, lib:bytes, var_vals, rawbufs, early_stop=None, max_
    try: car = CompiledRunner(p, precompiled=lib)
    except AssertionError: return [math.inf] * cnt
    tms = []
-   input_bufs = [rawbufs[i] for i,_ in car.p.globals]
+   input_bufs = [rawbufs[i] for i in car.p.globals]
    for _ in range(cnt):
      if clear_l2:
-       with Context(DEBUG=0, BEAM=0, CAPTURING=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
+       if hasattr(dev:=Device[p.dname], 'invalidate_caches'): dev.invalidate_caches()
+       else:
+         with Context(DEBUG=0, BEAM=0, CAPTURING=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
      tms.append(cast(float, car(input_bufs, var_vals, wait=True))*factor)
-     if early_stop is not None and early_stop < tms[-1]: break
+     if early_stop is not None and early_stop < min(tms): break
    return tms

  class TimeoutException(Exception): pass
  def timeout_handler(signum, frame): raise TimeoutException()

- def _try_compile_linearized_w_idx(x:Tuple[int,Linearizer], compiler:Compiler) -> Tuple[int, Optional[Tuple[Program, bytes, float]]]:
-   signal.signal(signal.SIGALRM, timeout_handler)
-   # set timeout
-   signal.alarm(getenv("BEAM_TIMEOUT_SEC", 10))
+ def _try_compile_linearized_w_idx(x:Tuple[int,Kernel], compiler:Compiler) -> Tuple[int, Optional[Tuple[Program, bytes, float]]]:
+   if hasattr(signal, "alarm"):
+     signal.signal(getattr(signal, 'SIGALRM'), timeout_handler)
+     # set timeout
+     signal.alarm(getenv("BEAM_TIMEOUT_SEC", 10))
+   ret = None
    try:
-     x[1].linearize()
-     if len(x[1].uops.uops) >= getenv("BEAM_UOPS_MAX", 3000) > 0: raise RuntimeError("too many uops")
-     p = x[1].to_program()
+     p = x[1].to_program(name_override="test")
+     assert p.uops is not None, "uop list wasn't generated?"
+     if len(p.uops) >= getenv("BEAM_UOPS_MAX", 3000) > 0: raise RuntimeError("too many uops")
      st = time.perf_counter()
      prog = compiler.compile(p.src)
      et = time.perf_counter() - st
      ret = (p, prog, et)
    except RuntimeError:
      if DEBUG >= 4: traceback.print_exc()
-     ret = None
-   except TimeoutException:
-     ret = None
    except Exception as e:
      if getenv("BEAM_STRICT_MODE"): raise e
-     ret = None
    finally:
-     signal.alarm(0)
+     if hasattr(signal, "alarm"): signal.alarm(0)
    return x[0], ret

  # workers should ignore ctrl c
@@ -85,19 +85,22 @@ def _ensure_buffer_alloc(bufs:List[Buffer]) -> List[Buffer]: return [buf.ensure_
  # *** external API ***

  # get (scrap) buffers for timing the linearizer
- def bufs_from_lin(lin:Linearizer, allocate:bool=True) -> List[Buffer]:
-   bufsts:DefaultDict[int, List[MemBuffer]] = defaultdict(list)
-   for x in lin.membufs: bufsts[x.idx].append(x)
-   rawbufs:List[Optional[Buffer]] = [None]*len(bufsts)
+ def bufs_from_lin(lin:Kernel, allocate:bool=True) -> List[Buffer]:
+   bufsts: DefaultDict[int, List[UOp]] = defaultdict(list)
+   for x in lin.bufs:
+     if x.src[0].op is Ops.DEFINE_GLOBAL: bufsts[x.src[0].arg].append(x)
+   rawbufs: List[Optional[Buffer]] = [None]*len(bufsts)
    for k,lx in bufsts.items():
-     buf_size = prod(lx[0].dtype.shape) if isinstance(lx[0].dtype, ImageDType) else max(y.st.real_size() for y in lx)
+     buf_size = prod(dtype.shape) if isinstance(dtype:=lx[0].src[0].dtype, ImageDType) else max(y.st_arg.real_size() for y in lx)
+     assert isinstance(dtype, (PtrDType, ImageDType))
      if buf_size == 0: buf_size = 1 # create a size 1 buffer if no cell is accessed in kernel. # TODO: remove from kernel input in this case.
-     rawbufs[k] = Buffer(lin.opts.device, buf_size, lx[0].dtype).allocate() if allocate else Buffer(lin.opts.device, buf_size, lx[0].dtype)
+     buf_dtype = dtype if isinstance(dtype, ImageDType) else dtype.base
+     rawbufs[k] = Buffer(lin.opts.device, buf_size, buf_dtype).allocate() if allocate else Buffer(lin.opts.device, buf_size, buf_dtype)
    assert all(r is not None for r in rawbufs)
    return cast(List[Buffer], rawbufs)

  # get dictionary of all possible actions
- def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Linearizer]:
+ def get_kernel_actions(lin:Kernel, include_0=True) -> Dict[int, Kernel]:
    acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024)
    for i,a in enumerate(actions):
      if a.axis is not None and a.op is not OptOps.TC:
@@ -114,19 +117,19 @@ def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Lineariz
      except KernelOptError: pass
    return acted_lins

- beam_pool, BEAM_DEBUG = None, getenv("BEAM_DEBUG")
- def beam_search(lin:Linearizer, rawbufs:List[Buffer], amt:int, allow_test_size=True) -> Linearizer:
+ beam_pool, BEAM_DEBUG, CAPTURE_BEAM = None, getenv("BEAM_DEBUG"), getenv("CAPTURE_BEAM", "")
+ def beam_search(lin:Kernel, rawbufs:List[Buffer], amt:int, allow_test_size=True, disable_cache=getenv("IGNORE_BEAM_CACHE")) -> Kernel:
    global beam_pool
-   key = {"ast": lin.ast[0].key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix}
-   if not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1 and (val:=diskcache_get("beam_search", key)) is not None:
+   key = {"ast": lin.ast.key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix}
+   if not disable_cache and CACHELEVEL >= 1 and (val:=diskcache_get("beam_search", key)) is not None:
      ret = lin.copy()
      for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
      return ret

-   beam: List[Tuple[Linearizer, float]] = [(lin, float("inf"))]
+   beam: List[Tuple[Kernel, float]] = [(lin, float("inf"))]
    seen_libs = set()

-   default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "AMD", "NV"} else 0
+   default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "AMD", "NV", "METAL"} else 0
    if beam_pool is None and (workers := getenv("PARALLEL", default_parallel)):
      beam_pool = multiprocessing.get_context("spawn").Pool(workers, _init_worker, (), getenv("BEAM_MAX_TASKS_PER_CHILD", 16))

@@ -136,23 +139,31 @@ def beam_search(lin:Linearizer, rawbufs:List[Buffer], amt:int, allow_test_size=T

    try:
      rawbufs = _ensure_buffer_alloc(rawbufs)
-     var_vals = {k:(k.max+k.min)//2 for k in lin.ast[0].vars()}
+     var_vals: Dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
      exiting, st = False, time.perf_counter()
      dev = Device[lin.opts.device]
      while not exiting:
-       acted_lins: List[Linearizer] = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam])
-       timed_lins: List[Tuple[Linearizer, float]] = []
+       acted_lins: List[Kernel] = flatten([get_kernel_actions(lin, include_0=False).values() for lin,_ in beam])
+       timed_lins: List[Tuple[Kernel, float]] = []
        _compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=dev.compiler)
+       least_compute_ops = math.inf
        for i,proc in (map(_compile_fn, enumerate(acted_lins)) if beam_pool is None else beam_pool.imap_unordered(_compile_fn, enumerate(acted_lins))):
          if proc is None: continue
          p, lib, compile_et = proc
          if lib in seen_libs: continue
-         #print(acted_lins[i].colored_shape(), acted_lins[i].applied_opts) # for debugging BEAMs that segfault
+         # filter out kernels that use 1000x more compute than the smallest
+         least_compute_ops = min(this_compute_ops:=sym_infer(p.op_estimate, var_vals), least_compute_ops)
+         if least_compute_ops*1000 < this_compute_ops: continue
+         if len(CAPTURE_BEAM) > 0:
+           with open(CAPTURE_BEAM, 'a') as f: f.write(str(acted_lins[i].ast).replace('\n','')+f" :: {acted_lins[i].applied_opts}\n")
          seen_libs.add(lib)
-         try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0)
-         except RuntimeError: continue # for runtime issues
+         try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0, clear_l2=hasattr(dev, 'invalidate_caches'))
+         except RuntimeError as e:
+           if len(CAPTURE_BEAM) > 0:
+             with open(CAPTURE_BEAM, 'a') as f: f.write("# Upper ast finished with an error:" + str(e).replace('\n',' ')+ "\n")
+           continue # for runtime issues
          timed_lins.append((acted_lins[i], min(tms)))
-         if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(UOpGraph, p.uops).uops):5d} uops {compile_et*1e6:12.2f} us compile/{timed_lins[-1][1]*1e6:12.2f} us run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501
+         if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(List, p.uops)):5d} uops {compile_et*1e6:12.2f} us compile/{timed_lins[-1][1]*1e6:12.2f} us run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501
          elif DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501

        # done
@@ -181,8 +192,8 @@ def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[Buff
    assert not math.isinf(ret[0]), "all optimize_local_size exec failed"
    return ret[1]

- def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501
-   key = {"ast": lin.ast[0].key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size,
+ def time_linearizer(lin:Kernel, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501
+   key = {"ast": lin.ast.key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size,
           "max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device, "suffix": lin.opts.suffix}
    if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)

@@ -190,7 +201,7 @@ def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True,
    assert dev.compiler is not None

    rawbufs = _ensure_buffer_alloc(rawbufs)
-   var_vals = {k:(k.max+k.min)//2 for k in lin.ast[0].vars()}
+   var_vals: Dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
    p = lin.to_program()
    tms = _time_program(p, dev.compiler.compile(p.src), var_vals, rawbufs,
                        max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))
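The public entry points above keep their names, but every Linearizer parameter becomes the merged Kernel class. A rough usage sketch; the way the Kernel is constructed from a schedule item reflects 0.10.0 internals that this hunk does not show, so treat those lines as assumptions:

from tinygrad import Tensor, Device
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.search import bufs_from_lin, beam_search, time_linearizer

# schedule a small matmul; the last schedule item is assumed to be its compute kernel
si = (Tensor.rand(64, 64) @ Tensor.rand(64, 64)).schedule()[-1]
lin = Kernel(si.ast, opts=Device[Device.DEFAULT].renderer)  # assumed constructor shape
rawbufs = bufs_from_lin(lin)             # scrap buffers sized from the DEFINE_GLOBALs in the AST
best = beam_search(lin, rawbufs, amt=4)  # times get_kernel_actions candidates, returns the fastest Kernel
print(best.applied_opts, time_linearizer(best, rawbufs))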
tinygrad/function.py CHANGED
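The hunks below all make the same move: gradient Functions stop calling the removed e(UnaryOps/BinaryOps/TernaryOps, ...) wrappers and use plain operators and named methods on LazyBuffer instead. As an illustration of the new style only, a hypothetical Square function written the way Mul and Sqrt below are (it is not part of this diff):

from tinygrad.tensor import Function
from tinygrad.engine.lazy import LazyBuffer

class Square(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.x = x
    return x * x                     # 0.9.1 style would be x.e(BinaryOps.MUL, x)
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output * self.x * 2  # d(x*x)/dx = 2x; plain ints promote, like self.ret*2 in Sqrt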
@@ -3,10 +3,9 @@ import math
  from typing import Tuple, Optional
  from tinygrad.helpers import argsort
  from tinygrad.dtype import dtypes, DType, sum_acc_dtype
- from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
+ from tinygrad.ops import Ops, resolve, sint
  from tinygrad.tensor import Function
- from tinygrad.lazy import LazyBuffer
- from tinygrad.shape.symbolic import sint
+ from tinygrad.engine.lazy import LazyBuffer

  class Contiguous(Function):
    def forward(self, x:LazyBuffer) -> LazyBuffer: return x.contiguous()
@@ -19,95 +18,96 @@ class ContiguousBackward(Function):
  class Cast(Function):
    def forward(self, x:LazyBuffer, dtype:DType, bitcast:bool=False) -> LazyBuffer:
      self.input_dtype, self.bitcast = x.dtype, bitcast
-     return x.cast(dtype, bitcast)
+     return x.bitcast(dtype) if self.bitcast else x.cast(dtype)

-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.cast(self.input_dtype, self.bitcast)
+   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
+     if self.bitcast: raise RuntimeError("bitcast cannot backward")
+     return grad_output.cast(self.input_dtype)

  # ************* unary ops *************

- class Neg(Function):
-   def forward(self, x:LazyBuffer) -> LazyBuffer: return x.e(UnaryOps.NEG)
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(UnaryOps.NEG)
-
  class Reciprocal(Function):
    def forward(self, x:LazyBuffer) -> LazyBuffer:
-     self.ret = x.e(UnaryOps.RECIP)
+     self.ret = x.reciprocal()
      return self.ret
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-     return grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.ret).e(BinaryOps.MUL, self.ret)
+
+   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return -grad_output * self.ret * self.ret

  class Sin(Function):
    def forward(self, x:LazyBuffer) -> LazyBuffer:
      self.x = x
-     return x.e(UnaryOps.SIN)
+     return x.sin()

-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-     return self.x.const(math.pi / 2).e(BinaryOps.ADD, self.x.e(UnaryOps.NEG)).e(UnaryOps.SIN).e(BinaryOps.MUL, grad_output)
+   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return (math.pi/2 - self.x).sin() * grad_output

- # NOTE: maximum(x, 0) behaves differently where x=0
  class Relu(Function):
    def forward(self, x:LazyBuffer) -> LazyBuffer:
-     self.ret = x.e(BinaryOps.MAX, x.const(0))
+     self.ret = x.maximum(0)
      return self.ret

-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-     return self.ret.const(0).e(BinaryOps.CMPLT, self.ret).cast(grad_output.dtype).e(BinaryOps.MUL, grad_output)
+   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return self.ret.gt(0).cast(grad_output.dtype) * grad_output

  class Log(Function):
    def forward(self, x:LazyBuffer) -> LazyBuffer:
      self.x = x
-     return x.e(UnaryOps.LOG2).e(BinaryOps.MUL, x.const(math.log(2)))
+     return x.log2() * math.log(2)

-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(BinaryOps.MUL, self.x.e(UnaryOps.RECIP))
+   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output / self.x

  class Exp(Function):
    def forward(self, x:LazyBuffer) -> LazyBuffer:
-     self.ret = x.e(BinaryOps.MUL, x.const(1/math.log(2))).e(UnaryOps.EXP2)
+     self.ret = (x * (1/math.log(2))).exp2()
      return self.ret

-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return self.ret.e(BinaryOps.MUL, grad_output)
+   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return self.ret * grad_output

  class Sqrt(Function):
    def forward(self, x:LazyBuffer) -> LazyBuffer:
-     self.ret = x.e(UnaryOps.SQRT)
+     self.ret = x.sqrt()
      return self.ret

-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-     return grad_output.e(BinaryOps.MUL, self.ret.e(BinaryOps.MUL, self.ret.const(2)).e(UnaryOps.RECIP))
+   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output / (self.ret*2)

  # NOTE: the implicit derivative of sigmoid is not stable
  # https://towardsdatascience.com/derivative-of-the-sigmoid-function-536880cf918e
  # TODO: have the backend automatically find this
  class Sigmoid(Function):
    def forward(self, x:LazyBuffer) -> LazyBuffer:
-     self.ret = x.const(1).e(BinaryOps.ADD, x.e(BinaryOps.MUL, x.const(-1/math.log(2))).e(UnaryOps.EXP2)).e(UnaryOps.RECIP)
+     self.ret = (1 + (x * (-1/math.log(2))).exp2()).reciprocal()
      return self.ret

    def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-     return self.ret.e(BinaryOps.MUL, self.ret.const(1).e(BinaryOps.ADD, self.ret.e(UnaryOps.NEG))).e(BinaryOps.MUL, grad_output)
+     return (self.ret * (1 - self.ret)) * grad_output

  class Sign(Function):
-   def forward(self, x:LazyBuffer) -> LazyBuffer:
-     return x.e(BinaryOps.CMPNE, x.const(0)).e(
-       TernaryOps.WHERE, x.e(BinaryOps.CMPLT, x.const(0)).e(TernaryOps.WHERE, x.const(-1), x.const(1)), x.const(0))
+   def forward(self, x:LazyBuffer) -> LazyBuffer: return x.ne(0).where(x.lt(0).where(x.const_like(-1), x.const_like(1)), x.const_like(0))
    # backward always return 0 to match torch
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.const(0)
+   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.const_like(0)

  # ************* binary ops *************

  class Less(Function):
-   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPLT, y)
+   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.lt(y)
    def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None

  class Neq(Function):
-   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPNE, y)
+   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.ne(y)
    def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None

  class Xor(Function):
-   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.XOR, y)
+   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x^y
+
+ class BitwiseAnd(Function):
+   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x&y
+
+ class BitwiseOr(Function):
+   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x|y
+
+ class Threefry(Function):
+   def forward(self, x:LazyBuffer, seed:LazyBuffer) -> LazyBuffer: return x.threefry(seed)

  class Add(Function):
-   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.ADD, y)
+   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x+y

    def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
      return grad_output if self.needs_input_grad[0] else None, \
@@ -116,64 +116,65 @@ class Add(Function):
  class Mul(Function):
    def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
      self.x, self.y = x, y
-     return x.e(BinaryOps.MUL, y)
+     return x * y

    def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
-     return self.y.e(BinaryOps.MUL, grad_output) if self.needs_input_grad[0] else None, \
-            self.x.e(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None
+     return (self.y * grad_output) if self.needs_input_grad[0] else None, \
+            (self.x * grad_output) if self.needs_input_grad[1] else None

- class Div(Function):
-   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-     self.x, self.y = x, y
-     return x.e(BinaryOps.MUL, y.e(UnaryOps.RECIP)) if not dtypes.is_int(x.dtype) else x.e(BinaryOps.IDIV, y)
-
-   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
-     return grad_output.e(BinaryOps.MUL, self.y.e(UnaryOps.RECIP)) if self.needs_input_grad[0] else None, \
-            grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.x).e(BinaryOps.MUL, self.y.e(BinaryOps.MUL, self.y).e(UnaryOps.RECIP)) if self.needs_input_grad[1] else None # noqa: E501
+ class IDiv(Function):
+   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x // y

  # ************* ternary ops *************

  class Where(Function):
    def forward(self, x:LazyBuffer, y:LazyBuffer, z:LazyBuffer) -> LazyBuffer:
      self.x = x
-     return self.x.e(TernaryOps.WHERE, y, z)
+     return self.x.where(y, z)

    def backward(self, grad_output:LazyBuffer) -> Tuple[None, Optional[LazyBuffer], Optional[LazyBuffer]]:
      return None, \
-       self.x.e(TernaryOps.WHERE, grad_output, grad_output.const(0)) if self.needs_input_grad[1] else None, \
-       self.x.e(TernaryOps.WHERE, grad_output.const(0), grad_output) if self.needs_input_grad[2] else None
+       self.x.where(grad_output, grad_output.const_like(0)) if self.needs_input_grad[1] else None, \
+       self.x.where(grad_output.const_like(0), grad_output) if self.needs_input_grad[2] else None

  # ************* reduce ops *************

  class Sum(Function):
    def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
      self.input_shape = x.shape
-     return x.r(ReduceOps.SUM, axis)
+     return x.r(Ops.ADD, axis)

    def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.expand(self.input_shape)

+ class Prod(Function):
+   def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
+     self.x, self.ret = x, x.r(Ops.MUL, axis)
+     return self.ret
+
+   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
+     return (grad_output * self.ret).expand(self.x.shape) / self.x
+
  class Max(Function):
    def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
-     self.x, self.ret, self.axis = x, x.r(ReduceOps.MAX, axis), axis
+     self.x, self.ret, self.axis = x, x.r(Ops.MAX, axis), axis
      return self.ret

    def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
      # 1s in locations where the max was chosen (can be two locations)
-     max_is_1s = self.x.const(1.0).cast(dtypes.float).e(BinaryOps.ADD, self.x.e(BinaryOps.CMPNE, \
-       self.ret.expand(self.x.shape)).cast(dtypes.float).e(UnaryOps.NEG))
-     div = max_is_1s.r(ReduceOps.SUM, self.axis).expand(self.x.shape)
-     return max_is_1s.e(BinaryOps.MUL, div.e(UnaryOps.RECIP)).cast(grad_output.dtype).e(BinaryOps.MUL, grad_output.expand(self.x.shape))
+     max_is_1s = self.x.ne(self.ret.expand(self.x.shape)).ne(self.x.const_like(1).cast(dtypes.bool)).cast(grad_output.dtype)
+     div = max_is_1s.r(Ops.ADD, self.axis).expand(self.x.shape)
+     return (max_is_1s/div) * grad_output.expand(self.x.shape)

  # ************* movement ops *************

  # NOTE: this is sum in reverse
  class Expand(Function):
    def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
-     self.expanded_axis = tuple(i for i, (si, so) in enumerate(zip(x.shape, shape)) if si != so)
+     self.expanded_axis = tuple(i for i, (si, so) in enumerate(zip(x.shape, shape)) if resolve(si != so))
      return x.expand(shape)

    def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-     return grad_output.cast(sum_acc_dtype(grad_output.dtype)).r(ReduceOps.SUM, self.expanded_axis).cast(grad_output.dtype)
+     return grad_output.cast(sum_acc_dtype(grad_output.dtype)).r(Ops.ADD, self.expanded_axis).cast(grad_output.dtype)

  class Reshape(Function):
    def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer: