tinygrad-0.10.2-py3-none-any.whl → tinygrad-0.11.0-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
Files changed (131)
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
tinygrad/{engine → codegen/opt}/search.py
@@ -1,44 +1,45 @@
-from typing import cast, Optional, Callable
+from typing import cast, Callable
 import itertools, functools, random, math, time, multiprocessing, traceback, signal, atexit
 from collections import defaultdict
 from dataclasses import replace
-from tinygrad.ops import UOp, Ops, Variable, sym_infer
+from tinygrad.uop.ops import UOp, Ops, Variable, sym_infer, AxisType
 from tinygrad.device import Device, Buffer, Compiler
 from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, time_to_str
 from tinygrad.helpers import IGNORE_BEAM_CACHE, TC_SEARCH_OVER_SHAPE
 from tinygrad.dtype import ImageDType, PtrDType
-from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
+from tinygrad.codegen.opt.kernel import Kernel, Opt, OptOps, KernelOptError
 from tinygrad.tensor import Tensor
-from tinygrad.engine.realize import CompiledRunner
+from tinygrad.engine.realize import CompiledRunner, get_program
 from tinygrad.renderer import ProgramSpec

-actions = [Opt(op=OptOps.UPCAST, axis=axis, arg=amt) for amt in [0,2,3,4,5,7] for axis in range(6)]
+actions = [Opt(op=OptOps.UPCAST, axis=axis, arg=amt) for amt in [0,2,3,4,5,7] for axis in range(8)]
 actions += [Opt(op=OptOps.UNROLL, axis=axis, arg=amt) for amt in [0,4,7] for axis in range(5)]
 actions += [Opt(op=OptOps.LOCAL, axis=axis, arg=amt) for amt in [2,3,4,8,13,16,29] for axis in range(6)]
 actions += [Opt(op=OptOps.GROUPTOP, axis=axis, arg=amt) for amt in [13,16,28,29,32,49,64,256] for axis in range(3)]
 actions += [Opt(op=OptOps.GROUP, axis=axis, arg=amt) for amt in [0,4,8,16] for axis in range(3)]
 if getenv("BEAM_PADTO", 1): actions += [Opt(op=OptOps.PADTO, axis=axis, arg=amt) for amt in [32] for axis in range(7)]
 actions += [Opt(op=OptOps.LOCAL, axis=0, arg=32), Opt(op=OptOps.LOCAL, axis=6, arg=2)]
-actions += [Opt(op=OptOps.TC, axis=0, arg=(-1, 0))]
-actions += [Opt(op=OptOps.TC, axis=axis, arg=(-1, getenv("TC_OPT", 2))) for axis in range(9)] # covers resnet kernels (3 global * 3 reduce)
+actions += [Opt(op=OptOps.TC, axis=0, arg=(-1, 0, getenv("TC", 1)))]
+# covers resnet kernels (3 global * 3 reduce)
+actions += [Opt(op=OptOps.TC, axis=axis, arg=(-1, getenv("TC_OPT", 2), getenv("TC", 1))) for axis in range(9)]
 actions += [Opt(op=OptOps.SWAP, axis=axis_0, arg=axis_1) for axis_0 in range(5) for axis_1 in range(axis_0+1, 5)]
 if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]

-def _get_test_global_size(global_size, max_global_size, var_vals):
-  test_global_size, factor = [sym_infer(sz, var_vals) for sz in global_size], 1
+def get_test_global_size(global_size, max_global_size, var_vals):
+  test_global_size = [sym_infer(sz, var_vals) for sz in global_size]
+  input_size = prod(test_global_size)
   while prod(test_global_size) > max_global_size:
     for j in range(len(global_size)-1,-1,-1):
       if test_global_size[j] > 16:
         test_global_size[j] //= 2
-        factor *= 2
         break
-  return test_global_size, factor
+  return test_global_size, input_size / prod(test_global_size)

-def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[Variable, int], rawbufs:list[Buffer], early_stop:Optional[float]=None,
-                  max_global_size:Optional[int]=65536, clear_l2=False, cnt=3, name="test") -> list[float]:
+def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[Variable, int], rawbufs:list[Buffer], early_stop:float|None=None,
+                  allow_test_size:int=True, max_global_size:int|None=65536, clear_l2=False, cnt=3, name="test") -> list[float]:
   factor = 1
-  if p.global_size is not None and max_global_size is not None:
-    global_size, factor = _get_test_global_size(p.global_size, max_global_size, var_vals)
+  if allow_test_size and p.global_size is not None and max_global_size is not None:
+    global_size, factor = get_test_global_size(p.global_size, max_global_size, var_vals)
     p = replace(p, global_size=global_size)
   try: car = CompiledRunner(p, precompiled=lib)
   except AssertionError: return [math.inf] * cnt
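Note on the hunk above: the renamed get_test_global_size now derives the timing scale factor as the ratio of the original launch size to the shrunken one, instead of accumulating a factor per halving. A standalone sketch of that logic (plain Python, not part of the diff, ignoring the symbolic-size handling done via sym_infer):

from math import prod

def get_test_global_size(global_size, max_global_size):
  # halve the last dimension still > 16 until the launch fits under max_global_size
  test = list(global_size)
  input_size = prod(test)
  while prod(test) > max_global_size:
    for j in range(len(test)-1, -1, -1):
      if test[j] > 16:
        test[j] //= 2
        break
  return test, input_size / prod(test)

# [1024, 1024, 256] with a 65536 cap shrinks to [256, 16, 16]; the 4096.0 factor rescales the measured time
print(get_test_global_size([1024, 1024, 256], 65536))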
@@ -56,16 +57,18 @@ def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[Variable, int], rawbuf
 class TimeoutException(Exception): pass
 def timeout_handler(signum, frame): raise TimeoutException()

-def _try_compile_linearized_w_idx(x:tuple[int,Kernel], compiler:Compiler) -> tuple[int, Optional[tuple[ProgramSpec, bytes, float]]]:
+def _try_compile_linearized_w_idx(x:tuple[int,Kernel], compiler:Compiler) -> tuple[int, tuple[ProgramSpec, bytes, float]|None]:
   if hasattr(signal, "alarm"):
     signal.signal(getattr(signal, 'SIGALRM'), timeout_handler)
     # set timeout
     signal.alarm(getenv("BEAM_TIMEOUT_SEC", 10))
   ret = None
   try:
-    p = x[1].to_program(name_override="test")
+    p = get_program(x[1].copy().get_optimized_ast(name_override="test"), x[1].opts)
     assert p.uops is not None, "uop list wasn't generated?"
-    if len(p.uops) >= getenv("BEAM_UOPS_MAX", 3000) > 0: raise RuntimeError("too many uops")
+    if len(p.uops) >= (uops_max:=getenv("BEAM_UOPS_MAX", 3000)) > 0:
+      if getenv("BEAM_LOG_SURPASS_MAX"): print(f"too many uops. {len(p.uops)=}, {uops_max=}")
+      raise RuntimeError("too many uops")
     st = time.perf_counter()
     prog = compiler.compile(p.src)
     et = time.perf_counter() - st
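The uop-count guard above uses a chained comparison with a walrus assignment, so the limit only applies when the cap is positive and BEAM_UOPS_MAX=0 disables it. A minimal illustration of the idiom (plain Python, not part of the diff):

def too_many_uops(n_uops: int, cap: int) -> bool:
  # chained comparison: True only if n_uops >= cap AND cap > 0
  return n_uops >= cap > 0

assert too_many_uops(5000, 3000) is True
assert too_many_uops(5000, 0) is False  # a cap of 0 turns the check off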
@@ -78,10 +81,12 @@ def _try_compile_linearized_w_idx(x:tuple[int,Kernel], compiler:Compiler) -> tup
     if hasattr(signal, "alarm"): signal.alarm(0)
   return x[0], ret

-# workers should ignore ctrl c
-def _init_worker(): signal.signal(signal.SIGINT, signal.SIG_IGN)
+# workers should not open devices and should ignore ctrl c and should not launch VIZ
+def _init_worker():
+  Context(ALLOW_DEVICE_USAGE=0, VIZ=0, TRACK_MATCH_STATS=0).__enter__()
+  signal.signal(signal.SIGINT, signal.SIG_IGN)

-def _ensure_buffer_alloc(bufs:list[Buffer]) -> list[Buffer]: return [buf.ensure_allocated() for buf in bufs]
+def _ensure_buffer_alloc(bufs:list[Buffer]) -> list[Buffer]: return [buf.ensure_allocated() if buf is not None else buf for buf in bufs]

 # *** external API ***

@@ -89,39 +94,46 @@ def _ensure_buffer_alloc(bufs:list[Buffer]) -> list[Buffer]: return [buf.ensure_
 def bufs_from_lin(lin:Kernel, allocate:bool=True) -> list[Buffer]:
   bufsts: defaultdict[int, list[UOp]] = defaultdict(list)
   for x in lin.bufs:
-    if x.src[0].op is Ops.DEFINE_GLOBAL: bufsts[x.src[0].arg].append(x)
-  rawbufs: list[Optional[Buffer]] = [None]*len(bufsts)
+    if x.src[0].base.op is Ops.DEFINE_GLOBAL: bufsts[x.src[0].base.arg].append(x)
+  # TODO: Nones are staying in here if buffers are optimized out!
+  # TODO: add a test for this
+  rawbufs: list[Buffer|None] = [None]*(max(bufsts)+1)
   for k,lx in bufsts.items():
     buf_size = prod(dtype.shape) if isinstance(dtype:=lx[0].src[0].dtype, ImageDType) else max(y.st_arg.real_size() for y in lx)
     assert isinstance(dtype, (PtrDType, ImageDType))
     if buf_size == 0: buf_size = 1 # create a size 1 buffer if no cell is accessed in kernel. # TODO: remove from kernel input in this case.
     buf_dtype = dtype if isinstance(dtype, ImageDType) else dtype.base
     rawbufs[k] = Buffer(lin.opts.device, buf_size, buf_dtype).allocate() if allocate else Buffer(lin.opts.device, buf_size, buf_dtype)
-  assert all(r is not None for r in rawbufs)
+  #assert all(r is not None for r in rawbufs)
   return cast(list[Buffer], rawbufs)

 # get dictionary of all possible actions
-def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]:
+def get_kernel_actions(lin:Kernel, include_0=True, candidates:list[Opt]|None=None) -> dict[int, Kernel]:
   acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024)
-  kernel_actions = actions.copy()
+  kernel_actions = (actions if candidates is None else candidates).copy()

   if TC_SEARCH_OVER_SHAPE and len(lin.applied_opts) == 0: # tensor core opts must be first
     for i, action in enumerate(kernel_actions):
       if action.op == OptOps.TC and (tc_arg := cast(tuple, action.arg))[0] == -1:
         # replace every tc_action with default tc with one tc_action for each available tc
-        kernel_actions[i:i+1] = [Opt(op=OptOps.TC, axis=action.axis, arg=(tc_select, tc_arg[1])) for tc_select,_ in enumerate(lin.opts.tensor_cores)]
+        kernel_actions[i:i+1] = \
+          [Opt(op=OptOps.TC, axis=action.axis, arg=(tc_select, tc_arg[1], tc_arg[2])) for tc_select,_ in enumerate(lin.opts.tensor_cores)]

   for i,a in enumerate(kernel_actions):
     if a.axis is not None and a.op is not OptOps.TC:
-      if ((ax:=lin.real_axis(a)) >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in kernel_actions): continue
+      try: ax = lin.real_axis(a.op, a.axis)
+      except KernelOptError: continue
+      if (ax >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, a.axis, 0) in kernel_actions): continue
     lin2 = lin.copy()
     try:
       lin2.apply_opt(a)
       up, lcl, tc_up = 1, 1, prod(tc.dims)//tc.threads if (tc:=lin2.tensor_core) else 1
-      for s,c in zip(lin2.full_shape, lin2.colors()):
-        if c in {"magenta", "yellow"}: up *= s
-        elif c in {"cyan", "green", "white"}: lcl *= s
-      if up//tc_up > max_up or lcl > max_lcl: continue
+      for s,c in zip(lin2.full_shape, lin2.axis_types):
+        if c in (AxisType.UPCAST, AxisType.UNROLL): up *= s
+        elif c in (AxisType.LOCAL, AxisType.GROUP_REDUCE): lcl *= s
+      if up//tc_up > max_up or lcl > max_lcl:
+        if getenv("BEAM_LOG_SURPASS_MAX"): print(f"too many upcast/local. {up//tc_up=}, {max_up=}, {lcl=}, {max_lcl=}")
+        continue
       acted_lins[i+1] = lin2
     except KernelOptError: pass
   return acted_lins
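Note: the hunk above swaps the old color-string accounting (magenta/yellow for upcasts, cyan/green/white for locals) for the new AxisType enum. A standalone sketch of the size-limit check it performs, using a stand-in enum rather than the real tinygrad.uop.ops.AxisType (illustrative only):

from enum import Enum, auto
from math import prod

class AxisType(Enum):  # stand-in for the real AxisType enum
  GLOBAL = auto(); LOCAL = auto(); UPCAST = auto(); UNROLL = auto(); GROUP_REDUCE = auto()

def within_beam_limits(full_shape, axis_types, max_up=256, max_lcl=1024, tc_up=1):
  # mirror the accounting above: upcast/unroll axes multiply `up`, local/group-reduce axes multiply `lcl`;
  # actions whose kernels exceed the caps (BEAM_UPCAST_MAX / BEAM_LOCAL_MAX defaults) are skipped
  up = prod(s for s, t in zip(full_shape, axis_types) if t in (AxisType.UPCAST, AxisType.UNROLL))
  lcl = prod(s for s, t in zip(full_shape, axis_types) if t in (AxisType.LOCAL, AxisType.GROUP_REDUCE))
  return up // tc_up <= max_up and lcl <= max_lcl

print(within_beam_limits((256, 16, 4, 4), (AxisType.GLOBAL, AxisType.LOCAL, AxisType.UPCAST, AxisType.UNROLL)))  # True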
@@ -138,7 +150,7 @@ def beam_search(lin:Kernel, rawbufs:list[Buffer], amt:int, allow_test_size=True,
   beam: list[tuple[Kernel, float]] = [(lin, float("inf"))]
   seen_libs = set()

-  default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "AMD", "NV", "METAL"} else 0
+  default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "AMD", "NV", "METAL", "HIP"} else 0
   if beam_pool is None and (workers := getenv("PARALLEL", default_parallel)):
     beam_pool = multiprocessing.get_context("spawn").Pool(workers, _init_worker, (), getenv("BEAM_MAX_TASKS_PER_CHILD", 16))
   @atexit.register
@@ -166,8 +178,12 @@ def beam_search(lin:Kernel, rawbufs:list[Buffer], amt:int, allow_test_size=True,
         least_compute_ops = min(this_compute_ops:=sym_infer(p.estimates.ops, var_vals), least_compute_ops)
         if least_compute_ops*1000 < this_compute_ops: continue
         seen_libs.add(lib)
-        try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0, clear_l2=hasattr(dev, 'invalidate_caches'))
-        except RuntimeError: continue # for runtime issues
+        try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0,
+                                 allow_test_size=allow_test_size, clear_l2=hasattr(dev, 'invalidate_caches'))
+        except Exception as e:
+          if BEAM_DEBUG: print(f"BEAM failed for opts: {acted_lins[i].applied_opts}\n{e}")
+          if isinstance(e, RuntimeError): continue
+          raise
         timed_lins.append((acted_lins[i], min(tms)))
         if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(list, p.uops)):5d} uops {time_to_str(compile_et, w=12)} compile/{time_to_str(timed_lins[-1][1], w=12)} run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501
         elif DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {time_to_str(timed_lins[-1][1], w=12)} {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501
tinygrad/codegen/opt/swizzler.py (new file)
@@ -0,0 +1,134 @@
+from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, resolve, sint
+from tinygrad.helpers import all_same, prod, unwrap, colored
+from tinygrad.shape.shapetracker import ShapeTracker
+from tinygrad.shape.view import View, strides_for_shape, get_contraction_with_reduce
+from tinygrad.schedule.grouper import ALWAYS_CONTIGUOUS
+from tinygrad.dtype import ImageDType, dtypes
+
+merge_views = PatternMatcher([
+  # merge adjacent views
+  (UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="v1"),), name="v2"), lambda v1,v2: v1.replace(arg=v1.arg+v2.arg)),
+  # replace MovementOps with VIEW
+  (UPat(GroupOp.Movement, src=(UPat.var("x"),), name="mop"), lambda mop,x: x.base.view(mop.st)),
+  # remove NOOP views
+  (UPat.var("x").view(name="view"),
+   lambda x,view: x if x.st is not None and x.op not in GroupOp.Defines and view.st.contiguous and view.shape == x.shape else None),
+  (UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}).view(name="view"),
+   lambda view: view.const_like(0) if (mask:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in mask) else None),
+  # only unmasked VIEW on CONST replaces the ShapeTracker
+  (UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),), name="view"),
+   lambda x,view: x.replace(src=(x.src[0].replace(arg=x.st+view.st),)) if all(v.mask is None for v in (x.st+view.st).views) else None),
+])
+
+def reduce_push_add_ones(src:UOp, r:UOp, view:UOp):
+  # contiguous, expand, and the same with ones removed
+  if unwrap(view.st).contiguous and len(r.shape) < len(view.shape) and \
+      tuple(x for x in r.shape if resolve(x != 1)) == tuple(x for x in view.shape if resolve(x != 1)):
+    new_shape: list[sint] = []
+    new_reduce_axis = []
+    if (contraction:=get_contraction_with_reduce(view.shape, r.shape, r.arg[1])) is None: return None
+    for i,pairs in enumerate(contraction):
+      new_shape_chunk = [view.shape[p] for p in pairs]
+      if i in r.arg[1]:
+        # if this is a reduce axis, we need a 1 in the view here to put it
+        assert len(new_shape_chunk) > 0
+        new_shape += [1]*(len(pairs)-1) + [src.shape[i]]
+        new_reduce_axis.append(len(new_shape)-1)
+      else:
+        # otherwise, pass through the new_shape_chunk
+        new_shape += new_shape_chunk
+    ret = r.replace(src=(src.reshape(tuple(new_shape)),), arg=(r.arg[0], tuple(new_reduce_axis))+r.arg[2:])
+    assert ret.shape == view.shape, f"shape mismatch on reduce_push_add_ones, {ret.shape} != {view.shape}"
+    return ret
+  return None
+
+view_left = merge_views+PatternMatcher([
+  # view before elementwise and buffer ops
+  (UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.BIND, Ops.STORE, Ops.VALID, Ops.SINK}, name="e"),), name="view"),
+   lambda e,view: e.replace(src=tuple(s.view(view.st) for s in e.src))),
+  # if there's ones added after reduce, put this before the reduce
+  (UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), reduce_push_add_ones),
+])
+
+view_left_through_load = PatternMatcher([
+  # view before load
+  (UPat(Ops.VIEW, src=(UPat(Ops.LOAD, name="e"),), name="view"),
+   lambda e,view: e.replace(src=tuple(s.view(view.st) for s in e.src))),
+])
+
+def apply_swizzle(u:UOp) -> UOp: return graph_rewrite(u, view_left, name="Sub View Left")
+
+# change reduceop axes and input ShapeTrackers, view gets replaced with a reshape.
+def swizzle_reduceop(r:UOp, src:UOp, view:UOp, fuse=False):
+  # contiguous and same size can push to children
+  # if there's a reduce child, shapes match with ones removed
+  if unwrap(view.st).contiguous and view.size == r.size and \
+    (not (len(r.arg) == 3 and r.arg[2]) or # arg[2] = True is fuse marker
+     tuple((i,x) for i,x in enumerate(r.shape) if resolve(x != 1)) == tuple((i,x) for i,x in enumerate(view.shape) if resolve(x != 1))):
+    return None
+  # swizzle the input
+  input_st = ShapeTracker.from_shape(src.shape)
+  tmp = input_st.permute(tuple(i for i in range(len(input_st.shape)) if i not in r.axis_arg)+r.axis_arg)
+  prshape = prod(rshape:=tmp.shape[-len(r.axis_arg):])
+  strides = strides_for_shape(rshape)
+  nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides,
+                    v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in unwrap(view.st).views]
+  new_view = tmp + ShapeTracker(tuple(nv))
+  swizzled_input = apply_swizzle(src.view(new_view))
+  # create a new reduceop
+  new_axis = tuple(range(len(view.shape), len(view.shape) + len(r.axis_arg)))
+  if fuse: red = UOp(Ops.REDUCE_AXIS, r.dtype, (swizzled_input.fuse(),), (r.arg[0], new_axis, True))
+  else: red = UOp(Ops.REDUCE_AXIS, r.dtype, (swizzled_input,), (r.arg[0], new_axis))
+  return red.reshape(view.shape)
+
+def reduceop_view_right(src:UOp, v:UOp, r:UOp):
+  assert unwrap(v.st).contiguous and v.size == src.size, f"can't compute new axis for {src.shape} -> {r.shape}"
+  new_axis = [i for i,(s,u) in enumerate(zip(src.shape, r.shape)) if s != u]
+  return src.r(r.arg[0], tuple(new_axis)).reshape(r.shape)
+
+def elementwise_view_right(root:UOp):
+  if not (swizzles:=[x for x in root.src if x.op is Ops.VIEW and x.base.op not in ALWAYS_CONTIGUOUS]): return None
+  assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}"
+  # place view after applying the elementwise op
+  new_st = ShapeTracker.from_shape(swizzles[0].base.shape)
+  new_src = [x.base if x.base.shape==new_st.shape else apply_swizzle(x.view(new_st)) for x in root.src]
+  # reshape to match downstream shapes
+  return root.replace(src=tuple(new_src)).reshape(root.shape)
+
+# push VIEW to children
+view_right = merge_views+PatternMatcher([
+  # push a non contiguous ShapeTracker through reduceop
+  (UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), swizzle_reduceop),
+  # apply view after reduceops
+  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="src"),), name="v"),), name="r"), reduceop_view_right),
+  # apply view after elementwise ops
+  (UPat(GroupOp.All-{Ops.SINK, Ops.REDUCE_AXIS}, name="root"), elementwise_view_right),
+  # merge axes for double reduce (invert of SPLIT_REDUCEOP=1)
+  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="r1"),), name="r2"),
+   lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] is r2.arg[0] else None),
+  # remove view from sink
+  (UPat(Ops.VIEW, name="v").sink(name="sink"), lambda v,sink: v.src[0].sink(arg=sink.arg)),
+])
+
+def check_load_st(glbl:UOp, view:UOp):
+  if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return
+  # if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
+  if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: return
+  # if it has a single view and it's equal when you shrink a contig, it's fine
+  if len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): return
+  # otherwise, it's not fine
+  raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+                     +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
+
+fix_kernel_ops = view_left_through_load+PatternMatcher([
+  # add view to LOAD and STORE
+  (UPat(Ops.DEFINE_GLOBAL, name="g").load(), lambda g: g.view(g.st).load()),
+  (UPat(Ops.DEFINE_GLOBAL, name="g").store(UPat.var('x')), lambda g,x: g.view(g.st).store(x)),
+  # VALID
+  (UPat(Ops.VIEW, src=(UPat.cvar(),), name="self"),
+   lambda self: UOp.where(UOp(Ops.VALID, dtypes.bool, (UOp(Ops.VIEW, arg=self.st),)), self.const_like(self.base.arg), 0)),
+  # no ImageDType after index
+  (UPat(GroupOp.All-{Ops.DEFINE_GLOBAL, Ops.VIEW}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
+  # if this kernel also assigns to the loaded buffer, ensure we can index it correctly
+  (UPat(Ops.LOAD, src=(UPat.var("glbl").view(name="view"),)), check_load_st),
+])
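One of the view_right rules above merges back-to-back REDUCE_AXIS ops of the same kind into a single reduce over the combined axes. At the Tensor level, the equivalence it relies on is simply the following (a short illustrative check, not part of the diff, assuming the usual Tensor.sum API):

from tinygrad import Tensor

t = Tensor.arange(12).reshape(3, 4)
# reducing over axis 1 and then axis 0 equals one reduce over both axes
assert t.sum(axis=1).sum(axis=0).item() == t.sum(axis=(0, 1)).item() == 66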
tinygrad/codegen/opt/tc.py (new file)
@@ -0,0 +1,127 @@
+import math, functools
+from dataclasses import dataclass
+from tinygrad.dtype import DType, dtypes
+from tinygrad.helpers import getenv
+
+@dataclass(frozen=True)
+class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N)
+  dims: tuple[int,int,int] # N, M, K
+  threads: int # number of threads that construct the warp
+  elements_per_thread: tuple[int, int, int] # elements per-thread to load/store from A/B/C
+  dtype_in: DType # dtype for A and B
+  dtype_out: DType # dtype for C and D
+  opts: tuple[str, ...] # ordered tuple of "ux" or "lx" specifying kernel opts to perform. "ux" upcasts dim x and "lx" localizes dim x
+  # (local_swizzle, upcast_swizzle, reduce_swizzle)
+  # l<num> is the num axis of the locals, similar for u<num> and upcasts, r<num> and reduces
+  swizzle: tuple[tuple[tuple[str, ...], tuple[str, ...], tuple[str, ...]], tuple[tuple[str, ...], tuple[str, ...], tuple[str, ...]]]
+  @functools.cache # pylint: disable=method-cache-max-size-none
+  def _remaps(self) -> list[dict[str, str]]:
+    local_axes, upcast_axes, reduce_axes = len(self.get_local_axes()), len(self.get_upcast_axes()), len(self.get_reduce_axes())
+    fwd_st = [f"l{i}" for i in range(local_axes)] + [f"u{i}" for i in range(upcast_axes)] + [f"r{i}" for i in range(reduce_axes)]
+    return [dict(zip(fwd_st, sum(s, ()))) for s in self.swizzle]
+  def permutes_for_shape_str(self, shape_str:list[str]) -> tuple[tuple[int, ...], tuple[int, ...]]:
+    ret = [[shape_str.index(remap[ss]) if ss in remap else i for i,ss in enumerate(shape_str)] for remap in self._remaps()]
+    return tuple(ret[0]), tuple(ret[1])
+  def get_reduce_axes(self): return [(i, 2) for i in range(int(math.log2(self.dims[2])))]
+  def get_upcast_axes(self): return [opt for opt in self.opts if opt[0] == "u"]
+  def get_local_axes(self): return [opt for opt in self.opts if opt[0] == "l"]
+  def base_upcast_axes(self):
+    # this is defined in the swizzle. first we use the upcast axes, then the reduce
+    return ([f"r{i}" for i in range(len(self.get_reduce_axes()))] + [f"u{i}" for i in range(len(self.get_upcast_axes()))])[::-1]
+  def __str__(self): return "_".join(["WMMA"] + list(map(str, self.dims)) + [self.dtype_in.name, self.dtype_out.name])
+  def __post_init__(self):
+    # all axes have size 2, <local> <reduce> <upcast> is the order
+    local_axes, upcast_axes, reduce_axes = len(self.get_local_axes()), len(self.get_upcast_axes()), len(self.get_reduce_axes())
+    assert self.dims[0] * self.dims[1] == 2**(local_axes + upcast_axes), \
+      f"N({self.dims[0]}) x M({self.dims[1]}) != local({2**local_axes}) x upcast({2**upcast_axes}) with opts({self.opts})"
+    assert 2**local_axes == self.threads, f"{self.threads} threads construct the warp but found {2**local_axes} in {self.opts}"
+    assert 2**upcast_axes == self.elements_per_thread[2], \
+      f"{self.elements_per_thread[2]} elements from C are processed per thread but found {2**upcast_axes} in {self.opts}"
+    # check dims match opts
+    assert self.dims[0] == 2**len(gd:=[x for x in self.opts if x[1] == '0']), f"opts wrong on dims[0], {self.dims[0]} vs {gd}"
+    assert self.dims[1] == 2**len(gd:=[x for x in self.opts if x[1] == '1']), f"opts wrong on dims[1], {self.dims[1]} vs {gd}"
+    # NOTE: the K opts is implicitly set by the dim
+    # check swizzle
+    assert len(self.swizzle[0]) == 3 and len(self.swizzle[1]) == 3, "swizzle has wrong part count"
+    assert len(self.swizzle[0][0]) == len(self.swizzle[1][0]) == local_axes, "local swizzle size is wrong"
+    assert len(self.swizzle[0][1]) == len(self.swizzle[1][1]) == upcast_axes, "upcast swizzle size is wrong"
+    assert len(self.swizzle[0][2]) == len(self.swizzle[1][2]) == reduce_axes, "reduce swizzle size is wrong"
+    assert all(len(s) == local_axes+upcast_axes+reduce_axes for s in self._remaps()), "remaps are the wrong size"
+    # check elements_per_thread
+    un, ln = 0, 0
+    zero_stride_0 = []
+    zero_stride_1 = []
+    for o in self.opts:
+      if o[1] == '0': zero_stride_0.append(o[0] + str(un if o[0] == 'u' else ln))
+      if o[1] == '1': zero_stride_1.append(o[0] + str(un if o[0] == 'u' else ln))
+      if o[0] == 'u': un += 1
+      if o[0] == 'l': ln += 1
+    # NOTE: all the zero_stride dims can be placed in any order in the swizzle
+    upcasted_0 = [x for x in (self.swizzle[0][1] + self.swizzle[0][2]) if x not in zero_stride_0 and x[0] != 'l']
+    upcasted_1 = [x for x in (self.swizzle[1][1] + self.swizzle[1][2]) if x not in zero_stride_1 and x[0] != 'l']
+    assert 2**len(upcasted_0) == self.elements_per_thread[0], f"mismatch in elements_per_thread[0], {upcasted_0} vs {self.elements_per_thread[0]}"
+    assert 2**len(upcasted_1) == self.elements_per_thread[1], f"mismatch in elements_per_thread[1], {upcasted_1} vs {self.elements_per_thread[1]}"
+
+# ***** NVIDIA *****
+
+cuda_tc_opts = ("u0","l0","l0","l1","l1","l1","u1") # shared by all shapes with M=16 N=8
+
+# https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-multiply-accumulate-instructions
+cuda_81616 = [TensorCore(dims=(8,16,16), threads=32, elements_per_thread=(8,4,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
+  swizzle=((('r1', 'r2', 'l2', 'l3', 'l4'), ('u1', 'r3'), ('l0', 'l1', 'u0', 'r0')),
+           (('r1', 'r2', 'u0', 'l0', 'l1'), ('r0', 'r3'), ('l2', 'l3', 'l4', 'u1'))))
+  for di,do in [(dtypes.half,dtypes.float), (dtypes.bfloat16,dtypes.float), (dtypes.half,dtypes.half)]]
+cuda_8168_f16 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
+  swizzle=((('r1', 'r2', 'l2', 'l3', 'l4'), ('r0', 'u1'), ('l0', 'l1', 'u0')),
+           (('r1', 'r2', 'u0', 'l0', 'l1'), ('u1', 'r0'), ('l2', 'l3', 'l4'))))
+  for di,do in [(dtypes.half,dtypes.float), (dtypes.half,dtypes.half)]]
+cuda_8168_tf32 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=dtypes.float, dtype_out=dtypes.float, opts=cuda_tc_opts,
+  swizzle=((('r0', 'r1', 'l2', 'l3', 'l4'), ('u1', 'r2'), ('l0', 'l1', 'u0')),
+           (('r0', 'r1', 'u0', 'l0', 'l1'), ('u1', 'r2'), ('l2', 'l3', 'l4'))))]
+cuda_sm80: list[TensorCore] = cuda_81616 + cuda_8168_f16
+if getenv("ALLOW_TF32", 0): cuda_sm80 += cuda_8168_tf32
+cuda_sm75: list[TensorCore] = cuda_8168_f16
+
+# ***** AMD *****
+
+# https://gpuopen.com/learn/wmma_on_rdna3/
+amd_rdna3 = [TensorCore(dims=(16,16,16), threads=32, elements_per_thread=(16,16,8), dtype_in=di, dtype_out=do,
+  opts=("l0","l0","l0","l0","l1","u1","u1","u1"),
+  swizzle=((('l4', 'u0', 'u1', 'u2', 'l0'), ('r1', 'r2', 'r3'), ('l1', 'l2', 'l3', 'r0')),
+           (('l0', 'l1', 'l2', 'l3', 'l4'), ('r1', 'r2', 'r3'), ('u0', 'u1', 'u2', 'r0'))))
+  for di,do in [(dtypes.half,dtypes.float),(dtypes.half,dtypes.half),(dtypes.bfloat16,dtypes.float)]]
+amd_rdna4 = [TensorCore(dims=(16,16,16), threads=32, elements_per_thread=(8,8,8), dtype_in=di, dtype_out=do,
+  opts=("l0","l0","l0","l0","u1","u1","u1","l1"),
+  swizzle=((('u0', 'u1', 'u2', 'l4', 'r2'), ('r0', 'r1', 'r3'), ('l0', 'l1', 'l2', 'l3')),
+           (('l0', 'l1', 'l2', 'l3', 'r2'), ('r0', 'r1', 'r3'), ('l4', 'u0', 'u1', 'u2'))))
+  for di,do in [(dtypes.half,dtypes.float),(dtypes.half,dtypes.half),(dtypes.bfloat16,dtypes.float),(dtypes.bfloat16,dtypes.bfloat16)]]
+
+# https://gpuopen.com/learn/amd-lab-notes/amd-lab-notes-matrix-cores-readme
+amd_cdna = [TensorCore(dims=(16,16,16), threads=64, elements_per_thread=(4,4,4), dtype_in=di, dtype_out=do,
+  opts=("l0","l0","l0","l0","u1","u1","l1","l1"),
+  swizzle=((('u0', 'u1', 'l4', 'l5', 'r2', 'r3'), ('r0', 'r1'), ('l0', 'l1', 'l2', 'l3')),
+           (('l0', 'l1', 'l2', 'l3', 'r2', 'r3'), ('r0', 'r1'), ('l4', 'l5', 'u0', 'u1'))))
+  for di,do in [(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)]]
+
+# ***** Apple Metal *****
+
+metal = [TensorCore(dims=(8,8,8), threads=32, elements_per_thread=(2,2,2), dtype_in=di, dtype_out=do,
+  opts=("u0","l0","l1","l1","l0","l1"),
+  swizzle=((('r1', 'l1', 'l2', 'r2', 'l4'), ('r0',), ('u0', 'l0', 'l3')),
+           (('l0', 'r0', 'r1', 'l3', 'r2'), ('u0',), ('l1', 'l2', 'l4'))))
+  for di,do in [(dtypes.float,dtypes.float),(dtypes.half,dtypes.float),
+                (dtypes.half,dtypes.half),(dtypes.bfloat16,dtypes.float),(dtypes.bfloat16,dtypes.bfloat16)]]
+
+# ***** Apple AMX *****
+
+amx = [TensorCore(dims=(sz,sz,1), threads=1, elements_per_thread=(sz,sz,sz*sz), dtype_in=dt, dtype_out=dt,
+  swizzle=(((), ('u0', 'u1', 'u2', 'u3', 'u4', 'u5', 'u6', 'u7'), ()),
+           ((), ('u4', 'u5', 'u6', 'u7', 'u0', 'u1', 'u2', 'u3'), ())),
+  opts=("u0","u0","u0","u0","u1","u1","u1","u1")) for dt,sz in [(dt, 64 // dt.itemsize) for dt in [dtypes.float]]]
+
+# ***** Intel ****
+
+intel = [TensorCore(dims=(8,8,16), threads=8, elements_per_thread=(16,16,8), dtype_in=dtypes.half, dtype_out=dtypes.float,
+  opts=("l0","l0","l0","u1","u1","u1"),
+  swizzle=((('r1', 'r2', 'r3'), ('u0', 'u1', 'u2'), ('l0', 'l1', 'l2', 'r0')),
+           (('l0', 'l1', 'l2'), ('r1', 'r2', 'r3'), ('u0', 'u1', 'u2', 'r0'))))]
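The __post_init__ asserts above tie dims, threads, and elements_per_thread together: every opt axis has size 2, so the counts must multiply out. For the cuda_81616 shape defined in this file the arithmetic works out as follows (a standalone check using only the constants shown above, not part of the diff):

import math

dims, threads, elements_per_thread = (8, 16, 16), 32, (8, 4, 4)   # cuda_81616
opts = ("u0", "l0", "l0", "l1", "l1", "l1", "u1")                  # cuda_tc_opts

local_axes  = sum(o[0] == "l" for o in opts)   # 5 -> 2**5 = 32 threads in the warp
upcast_axes = sum(o[0] == "u" for o in opts)   # 2 -> 2**2 = 4 C elements per thread
reduce_axes = int(math.log2(dims[2]))          # 4 size-2 reduce axes from K=16

assert dims[0] * dims[1] == 2 ** (local_axes + upcast_axes)        # 8*16 == 128
assert 2 ** local_axes == threads and 2 ** upcast_axes == elements_per_thread[2]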
tinygrad/codegen/quantize.py (new file)
@@ -0,0 +1,67 @@
+from tinygrad.dtype import dtypes, least_upper_dtype
+from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat
+from tinygrad.uop.symbolic import symbolic
+
+# **** this is the "quantization preprocessor", it makes ONNX quantized models, and probably also others, actually use ints ****
+# this is badly tested and low quality. remove it?
+
+FP = (1 << 15)
+pm_quant = symbolic+PatternMatcher([
+  # cast after add/mul
+  (UPat.var("x").cast(dtypes.float32) + UPat.var("y").cast(dtypes.float32),
+   lambda x,y: (x.cast(least_upper_dtype(x.dtype, y.dtype))+y.cast(least_upper_dtype(x.dtype, y.dtype))).cast(dtypes.float32)),
+  (UPat.var("x").cast(dtypes.float32) * UPat.var("y").cast(dtypes.float32),
+   lambda x,y: (x.cast(least_upper_dtype(x.dtype, y.dtype))*y.cast(least_upper_dtype(x.dtype, y.dtype))).cast(dtypes.float32)),
+
+  # masked MUL after masked ADD
+  ((UPat.var("x") + UPat.var("v").where(UPat.var('cadd'), UPat(Ops.CONST, arg=0))) * UPat.var("v").where(UPat.var('cmul'), UPat(Ops.CONST, arg=0)),
+   lambda x,v,cadd,cmul: x*v.where(cmul, 0)+v.where(cadd*cmul, 0)),
+
+  # MUL after reduce
+  (UPat(Ops.REDUCE_AXIS, src=(UPat.var("x") * UPat.cvar("c"),), name="r"), lambda x,c,r: r.replace(src=(x,))*c.arg),
+  # CAST after reduce (doesn't work if it's a size change)
+  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, src=(UPat.var("x"),)),), name="r"),
+   lambda x,r: r.replace(dtype=x.dtype, src=(x,)).cast(r.dtype) if dtypes.is_float(r.dtype) else None),
+
+  # x*c1 + y*c2 -> (x+y)*c1 (if c1 and c2 are close floats)
+  (UPat.var("x")*UPat.cvar("c1", dtype=dtypes.floats) + UPat.var("y")*UPat.cvar("c2", dtype=dtypes.floats),
+   lambda x,y,c1,c2: (x+y)*c1 if abs(c1.arg-c2.arg) < 1e-9 else None),
+  # mul 0 * c1 is 0
+  (UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
+   UPat(Ops.LOAD, src=(UPat().view(name="v"),)).cast(dtypes.int).cast(dtypes.float).named("ld"), lambda ld,v,c1: ld*c1),
+  # mul (with plus) 0 * c1 is 0
+  (UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
+   (UPat(Ops.LOAD, src=(UPat().view(name="v"),)).cast(dtypes.int) + \
+    UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar(), UPat(Ops.CONST, arg=0))).cast(dtypes.float).named("ld"),
+   lambda ld,v,c1: ld*c1),
+
+  # const push through add
+  ((UPat.var("x")*UPat.cvar("c1") + UPat.var("y")*UPat.cvar("c2")) * UPat.cvar("c3"), lambda x,y,c1,c2,c3: (x*c1*c3) + (y*c2*c3)),
+
+  # fixed point mult, replace (x.float()*c1+c2).int() with an int expression
+  ((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("cc")).cast(dtypes.int),
+   lambda x,c1,cc: ((x*(c1*FP).cast(x.dtype) + (cc*FP).cast(x.dtype)) // FP).cast(dtypes.int)),
+  # fixed point mult, replace (x.float()*c1 + y.float()*c2)*cc.int() with an int expression
+  ((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("y").cast(dtypes.float)*UPat.var("c2")+UPat.var("cc")).cast(dtypes.int),
+   lambda x,c1,y,c2,cc: ((x*(c1*FP).cast(x.dtype) + y.cast(x.dtype)*(c2*FP).cast(x.dtype) + (cc*FP).cast(x.dtype)) // FP).cast(dtypes.int)),
+
+  # where move
+  (UPat.var("valid").where(UPat.var("yes"), UPat(Ops.CONST, arg=0))*UPat.var("mul"), lambda valid, yes, mul:
+   (yes*mul*valid.where(UOp.const(mul.dtype, 1), UOp.const(mul.dtype, 0))) if yes.op is not Ops.CONST or yes.arg != 1 else None),
+  ((UPat.var("x")*UPat.cvar("c"))*(UPat.var().where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)).named("v")), lambda x,c,v: (x*v)*c),
+  (UPat.var("x").cast().named('c') * UPat.var('valid').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)), lambda x,c,valid:
+   (x*valid.where(UOp.const(x.dtype, 1), UOp.const(x.dtype, 0))).cast(c.dtype)),
+  ((UPat.var('x') * UPat.var('v1').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)) *
+    UPat.var('v2').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0))).named("mul"), lambda x, mul, v1, v2:
+   x * (v1&v2).where(UOp.const(mul.dtype, 1), UOp.const(mul.dtype, 0))),
+
+  # where on two adds
+  (UPat.var("x") + UPat.var("v").where(UPat.var("a0"), UPat.var("a1")) + UPat.var("v").where(UPat.var("b0"), UPat.var("b1")),
+   lambda x,v,a0,a1,b0,b1: x + v.where(a0+b0, a1+b1)),
+
+  # split REDUCE into multiple reduces (who remembers FOIL?)
+  (UPat(Ops.REDUCE_AXIS, src=((UPat(Ops.CAST, name="v1")+UPat.var("c1")) * UPat(Ops.CAST, name="v2"),), name="r"),
+   lambda v1,v2,c1,r: r.replace(src=(v1*v2,)) + r.replace(src=(c1*v2,))),
+  (UPat(Ops.REDUCE_AXIS, src=((UPat(Ops.CAST, name="v1")+UPat.var("c1")) * (UPat(Ops.CAST, name="v2",)+UPat.var("c2")),), name="r"),
+   lambda v1,v2,c1,c2,r: r.replace(src=(v1*v2,)) + r.replace(src=(c2*v1,)) + r.replace(src=(c1*v2,)) + r.replace(src=(c1*c2,))),
+])
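The fixed-point rules in pm_quant replace a float multiply-add followed by an int cast with integer arithmetic scaled by FP = 1 << 15. A standalone numeric sketch of the identity they exploit (plain Python, not part of the diff; the real rewrite operates on UOps and dtypes):

FP = 1 << 15  # 15-bit fixed-point scale, as in quantize.py

def float_path(x: int, c1: float, cc: float) -> int:
  return int(x * c1 + cc)                    # what the graph computed before the rewrite

def fixed_point_path(x: int, c1: float, cc: float) -> int:
  # rewritten form: constants pre-scaled to integers, one integer division at the end
  return (x * int(c1 * FP) + int(cc * FP)) // FP

for x in range(256):
  # agrees to within 1, the error from truncating the constants to the fixed-point grid
  assert abs(float_path(x, 0.047, 3.0) - fixed_point_path(x, 0.047, 3.0)) <= 1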