tinygrad 0.9.0__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. tinygrad/codegen/__init__.py +0 -0
  2. tinygrad/codegen/kernel.py +78 -90
  3. tinygrad/codegen/linearizer.py +237 -169
  4. tinygrad/codegen/uops.py +278 -242
  5. tinygrad/device.py +147 -10
  6. tinygrad/dtype.py +7 -7
  7. tinygrad/engine/graph.py +16 -16
  8. tinygrad/engine/jit.py +39 -36
  9. tinygrad/engine/realize.py +6 -5
  10. tinygrad/engine/schedule.py +15 -7
  11. tinygrad/engine/search.py +6 -3
  12. tinygrad/function.py +17 -23
  13. tinygrad/helpers.py +77 -8
  14. tinygrad/lazy.py +26 -26
  15. tinygrad/multi.py +13 -9
  16. tinygrad/nn/__init__.py +1 -1
  17. tinygrad/nn/datasets.py +2 -1
  18. tinygrad/nn/state.py +3 -4
  19. tinygrad/ops.py +49 -16
  20. tinygrad/renderer/__init__.py +8 -4
  21. tinygrad/renderer/assembly.py +93 -100
  22. tinygrad/renderer/cstyle.py +47 -42
  23. tinygrad/renderer/llvmir.py +30 -30
  24. tinygrad/runtime/__init__.py +0 -0
  25. tinygrad/runtime/autogen/amd_gpu.py +11504 -1
  26. tinygrad/runtime/autogen/comgr.py +36 -10
  27. tinygrad/runtime/autogen/hsa.py +146 -14
  28. tinygrad/runtime/autogen/io_uring.py +1486 -0
  29. tinygrad/runtime/autogen/nv_gpu.py +269 -0
  30. tinygrad/runtime/driver/__init__.py +0 -0
  31. tinygrad/runtime/driver/hip_comgr.py +20 -11
  32. tinygrad/runtime/graph/__init__.py +0 -0
  33. tinygrad/runtime/graph/clang.py +3 -2
  34. tinygrad/runtime/graph/cuda.py +2 -2
  35. tinygrad/runtime/graph/hcq.py +122 -78
  36. tinygrad/runtime/ops_amd.py +302 -316
  37. tinygrad/runtime/ops_cuda.py +3 -3
  38. tinygrad/runtime/ops_disk.py +70 -5
  39. tinygrad/runtime/ops_gpu.py +2 -2
  40. tinygrad/runtime/ops_metal.py +5 -6
  41. tinygrad/runtime/ops_npy.py +1 -1
  42. tinygrad/runtime/ops_nv.py +161 -166
  43. tinygrad/runtime/ops_python.py +20 -16
  44. tinygrad/shape/__init__.py +0 -0
  45. tinygrad/shape/shapetracker.py +5 -2
  46. tinygrad/shape/symbolic.py +1 -3
  47. tinygrad/shape/view.py +34 -19
  48. tinygrad/tensor.py +219 -135
  49. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +14 -6
  50. tinygrad-0.9.1.dist-info/RECORD +63 -0
  51. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
  52. tinygrad/runtime/driver/hsa.py +0 -143
  53. tinygrad/runtime/graph/hsa.py +0 -171
  54. tinygrad/runtime/ops_hsa.py +0 -278
  55. tinygrad-0.9.0.dist-info/RECORD +0 -60
  56. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +0 -0
  57. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
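
The hunks below (tinygrad/renderer/assembly.py and tinygrad/renderer/cstyle.py) track a handful of renderer-level API changes between 0.9.0 and 0.9.1: UOp fields are renamed from uop/vin to op/src, BinaryOps.CMPEQ/DIV/SUB give way to CMPNE/IDIV (with UnaryOps.RECIP and BinaryOps.SHL/SHR added), renderer pattern matchers switch from plain dict patterns to UPat objects, and the HSA runtime files are removed in favor of the AMD backend. As a small illustrative sketch only (assuming the 0.9.1 module layout and UPat signature shown in the hunks below), the CMPNE-on-bool rule from ptx_matcher spelled in the new style:

    # imports as shown in the 0.9.1 hunks below (illustrative, not the full ptx_matcher)
    from tinygrad.codegen.uops import UOps, UOp, PatternMatcher, UPat
    from tinygrad.ops import BinaryOps
    from tinygrad.dtype import dtypes

    # 0.9.0 wrote this rule as a dict: {"uop": UOps.ALU, "arg": BinaryOps.CMPEQ, "vin": (...)}
    # 0.9.1 uses UPat objects, UOp fields .op/.src, and CMPNE instead of CMPEQ
    bool_cmpne_to_xor = PatternMatcher([
      (UPat(UOps.ALU, BinaryOps.CMPNE, (UPat(dtype=dtypes.bool), UPat()), "root"),
       lambda root: UOp(root.op, root.dtype, root.src, BinaryOps.XOR)),
    ])
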
tinygrad/renderer/assembly.py
@@ -1,11 +1,10 @@
  from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
- import struct
+ import struct, math
  from collections import defaultdict
  from tinygrad.helpers import DEBUG
- from tinygrad.codegen.linearizer import UOps, UOp
  from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op
  from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
- from tinygrad.codegen.uops import UOpGraph, PatternMatcher
+ from tinygrad.codegen.uops import UOps, UOp, UOpGraph, PatternMatcher, UPat
  from tinygrad.renderer import Renderer, TensorCore

  def render_val(x, dtype):
@@ -18,8 +17,8 @@ def render_val(x, dtype):
  class PTXRenderer(Renderer):
  device = "CUDA"
  suffix = "PTX"
- global_max = [65535, 65535, 2147483647]
- local_max = [64, 1024, 1024]
+ global_max = (2147483647, 65535, 65535)
+ local_max = (1024, 1024, 64)
  shared_max = 49152
  tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[0],[5],[-2],[0],[-1,1,2,-3],[3,4]], [[3],[4],[0],[0],[5],[-1,1,2,-2],[0]], [[-1],[1],[5],[-2],[2],[0],[3,4]] ], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float)])] # noqa: E501
  def __init__(self, arch:str): self.tensor_cores = PTXRenderer.tensor_cores if int(arch[3:]) >= 80 else []
@@ -30,29 +29,28 @@ class PTXRenderer(Renderer):
  .address_size 64
  .visible .entry"""
  barrier = "bar.sync\t0;"
- has_pred = True
- load_global = True
- label_prefix = "$"
  gid = [f'%ctaid.{chr(120+i)}' for i in range(3)]
  gdim = [f'%nctaid.{chr(120+i)}' for i in range(3)]
  lid = [f'%tid.{chr(120+i)}' for i in range(3)]
  asm_for_op: Dict[Op, Callable] = {
- UnaryOps.NEG: lambda d,a,dt,name: f"not.pred {d}, {a};" if name == "pred" else f"neg.{name} {d}, {a};",
+ UnaryOps.NEG: lambda d,a,dt,name: f"not.pred {d}, {a};" if name == "pred" else f"sub.{name} {d}, 0, {a};" if dtypes.is_unsigned(dt) \
+ else f"neg.{name} {d}, {a};",
+ UnaryOps.RECIP: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};",
  UnaryOps.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", UnaryOps.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
  UnaryOps.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", UnaryOps.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
+ BinaryOps.SHR: lambda d,a,b,dt,name: f"shr.{name} {d}, {a}, {b};", BinaryOps.SHL: lambda d,a,b,dt,name: f"shl.b{name[1:]} {d}, {a}, {b};",
  BinaryOps.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};",
- BinaryOps.SUB: lambda d,a,b,dt,name: f"sub.{name} {d}, {a}, {b};",
  BinaryOps.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
  BinaryOps.XOR: lambda d,a,b,dt,name: f"xor.pred {d}, {a}, {b};" if name == "pred" else f"xor.b{name[1:]} {d}, {a}, {b};",
- BinaryOps.DIV: lambda d,a,b,dt,name: f"div{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a}, {b};",
+ BinaryOps.IDIV: lambda d,a,b,dt,name: f"div.{name} {d}, {a}, {b};",
  BinaryOps.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", BinaryOps.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
  BinaryOps.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};",
- BinaryOps.CMPEQ: lambda d,a,b,dt,name: f"setp.eq.{name} {d}, {a}, {b};",
+ BinaryOps.CMPNE: lambda d,a,b,dt,name: f"setp.ne.{name} {d}, {a}, {b};",
  TernaryOps.MULACC: lambda d,a,b,c,dt,name: f"{'fma.rn' if dtypes.is_float(dt) else 'mad.lo'}.{name} {d}, {a}, {b}, {c};",
  TernaryOps.WHERE: lambda d,a,b,c,dt,name:
  f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
  }
- supports_half: List[Op] = [UnaryOps.NEG, UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT,
+ supports_half: List[Op] = [UnaryOps.NEG, UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT,
  TernaryOps.WHERE]
  # HACK: Use s16 and u16 for int8 and uint8 buffers. This can be wrong in cast.
  types: Dict[DType, str] = { dtypes.int8: "s16", dtypes.int16: "s16", dtypes.int32: "s32", dtypes.int64: "s64",
@@ -74,7 +72,7 @@ class PTXRenderer(Renderer):

  def render_loop(self, idx, start, label, acc=None) -> List[str]: return [f"mov.u32 {idx}, {start};", f"{label}:"]

- def render_bra(self, b1, pred=None, b2=None) -> List[str]: return [f"@{pred} bra {b1};", f"@!{pred} bra {b2};"] if pred else [f"bra {b1};"]
+ def render_bra(self, b1, pred=None) -> List[str]: return [f"@{pred} bra {b1};"] if pred else [f"bra {b1};"]

  def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="", offset=0) -> List[str]:
  assert dtype != dtypes.bool
@@ -118,14 +116,6 @@ class PTXRenderer(Renderer):
  if u is not None: r[u] = f"%{prefix}{c[prefix]-1}"
  return f"%{prefix}{c[prefix]-1}"

- c_label: DefaultDict[str, int] = defaultdict(int)
- r_label: Dict[UOp, str] = {}
- def ssa_label(prefix:str, u:UOp):
- nonlocal c_label, r_label
- c_label[prefix] += 1
- r_label[u] = f"{self.label_prefix}{prefix}_{c_label[prefix]-1}"
- return r_label[u]
-
  def const(x:ConstType, dtype:DType, mov=False):
  if mov or dtype in self.const_requires_mov:
  kk(*self.render_const(x, dtype, mov=(out:=ssa('const', dtype=self.types[dtype]))))
@@ -140,42 +130,42 @@ class PTXRenderer(Renderer):
  return ret

  for u in uops:
- uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
+ uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
  if uop is UOps.IF:
- assert vin[0].dtype is not None
- kk(*self.render_bra(lb:=ssa_label('if', u), _cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True), f"{lb}_true"), f"{lb}_true:")
+ assert src[0].dtype is not None
+ kk(*self.render_bra(f"IF_{r[src[0]][1:]}_{cast(List, uops._uops).index(u)}", _cast(r[src[0]], dtypes.bool, src[0].dtype, u=u, pred=True)))
  elif uop is UOps.BARRIER and self.barrier: kk(self.barrier)
  elif uop is UOps.ENDRANGE:
- kk(self.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, self.types[dtypes.int]),
- self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, self.types[dtypes.int]))
- kk(*self.render_bra(r_label[vin[0]], pred, f"{r_label[vin[0]]}_exit"), f"{r_label[vin[0]]}_exit:")
+ kk(self.asm_for_op[BinaryOps.ADD](r[src[0]], r[src[0]], "1", dtypes.int, self.types[dtypes.int]),
+ self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[src[0]], r[src[0].src[1]], dtypes.int, self.types[dtypes.int]))
+ kk(*self.render_bra(f"LOOP_{r[src[0]][1:]}", pred))
  elif uop is UOps.ENDIF:
- kk(f"{r_label[vin[0]]}:")
+ kk(f"IF_{r[src[0].src[0]][1:]}_{cast(List, uops._uops).index(src[0])}:")
  elif uop is UOps.STORE:
- assert vin[0].dtype is not None and vin[2].dtype is not None
- assert vin[0].dtype == dtypes.int64, "store isn't int64"
- assert vin[1].uop is UOps.CONST, f"store isn't const {u}"
- mem_type = '.shared' if vin[0].uop is UOps.DEFINE_LOCAL or any(x.uop is UOps.DEFINE_LOCAL for x in vin[0].parents) else '.global'
- if vin[2].dtype.count > 1:
- kk((f"@{r[vin[3]]} " if len(vin)>3 else "") + \
- f"st{mem_type}.v{vin[2].dtype.count}.{self.mem_types[vin[2].dtype.scalar()]} [{r[vin[0]]}+{vin[1].arg}], {{{', '.join(r[vin[2]])}}};")
+ assert src[0].dtype is not None and src[2].dtype is not None
+ assert src[0].dtype == dtypes.int64, "store isn't int64"
+ assert src[1].op is UOps.CONST, f"store isn't const {u}"
+ mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global'
+ if src[2].dtype.count > 1:
+ kk((f"@{r[src[3]]} " if len(src)>3 else "") + \
+ f"st{mem_type}.v{src[2].dtype.count}.{self.mem_types[src[2].dtype.scalar()]} [{r[src[0]]}+{src[1].arg}], {{{', '.join(r[src[2]])}}};")
  else:
- kk(*self.render_store(r[vin[0]], r[vin[2]], vin[2].dtype, gate=r[vin[3]] if len(vin)>3 else None, ss=mem_type, offset=vin[1].arg))
+ kk(*self.render_store(r[src[0]], r[src[2]], src[2].dtype, gate=r[src[3]] if len(src)>3 else None, ss=mem_type, offset=src[1].arg))
  else:
  assert dtype is not None, f"None dtype for uop {uop}"
- if uop is UOps.RANGE: kk(*self.render_loop(ssa('ridx', u), r[vin[0]], ssa_label('loop', u)))
+ if uop is UOps.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[src[0]], "LOOP_"+loop[1:]))
  elif uop is UOps.ALU:
- assert vin[0].dtype is not None
- if args is BinaryOps.CMPLT or args is BinaryOps.CMPEQ:
+ assert src[0].dtype is not None
+ if args is BinaryOps.CMPLT or args is BinaryOps.CMPNE:
  # pass in the other dtype here
- kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in vin], vin[0].dtype, self.types[vin[0].dtype]))
+ kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in src], src[0].dtype, self.types[src[0].dtype]))
  else:
- kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in vin], dtype, self.types[dtype]))
+ kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in src], dtype, self.types[dtype]))
  elif uop is UOps.DEFINE_ACC:
  if dtype.count > 1:
  r[u] = [ssa('acc', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
- for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(args[0], dtype.scalar())};")
- else: kk(f"mov.b{self.types[dtype][1:]} {ssa('acc', u)}, {const(args[0], dtype)};")
+ for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(src[0].arg, dtype.scalar())};")
+ else: kk(f"mov.b{self.types[dtype][1:]} {ssa('acc', u)}, {const(src[0].arg, dtype)};")
  elif uop is UOps.SPECIAL:
  assert args[1][0] != "i", "idx not supported"
  kk(f"mov.u32 %{args[1]}, {(self.gid if args[1][0] == 'g' else self.lid)[args[0]]};")
@@ -184,30 +174,30 @@ class PTXRenderer(Renderer):
  elif uop is UOps.CONST:
  if dtype.count > 1: r[u] = [const(args, dtype.scalar(), mov=True) for _ in range(dtype.count)]
  else: r[u] = const(args, dtype, mov=True)
- elif uop is UOps.GEP: r[u] = r[vin[0]][u.arg]
+ elif uop is UOps.GEP: r[u] = r[src[0]][u.arg]
  elif uop is UOps.LOAD:
- assert vin[0].dtype == dtypes.int64, "load isn't int64"
- assert vin[1].uop is UOps.CONST, f"load isn't const {u}"
- mem_type = '.shared' if vin[0].uop is UOps.DEFINE_LOCAL or any(x.uop is UOps.DEFINE_LOCAL for x in vin[0].parents) else '.global'
+ assert src[0].dtype == dtypes.int64, "load isn't int64"
+ assert src[1].op is UOps.CONST, f"load isn't const {u}"
+ mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global'
  if dtype.count > 1:
  r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
- if(len(vin)>3):
+ if(len(src)>3):
  for v in r[u]: kk(f"mov.{self.mem_types[dtype.scalar()]} {v}, {render_val(0, dtype.scalar())};")
- kk((f"@{r[vin[2]]}"if len(vin) > 3 else "")
- + f" ld{mem_type}.v{dtype.count}.{self.mem_types[dtype.scalar()]} {{{', '.join(r[u])}}}, [{r[vin[0]]}+{vin[1].arg}];")
+ kk((f"@{r[src[2]]}"if len(src) > 3 else "")
+ + f" ld{mem_type}.v{dtype.count}.{self.mem_types[dtype.scalar()]} {{{', '.join(r[u])}}}, [{r[src[0]]}+{src[1].arg}];")
  else:
- kk(*self.render_load(r[vin[0]], ssa('val', u), dtype, gate=r[vin[2]] if len(vin) > 3 else None,
- alt=r[vin[3]] if len(vin) > 3 else None, ss=mem_type, offset=vin[1].arg))
+ kk(*self.render_load(r[src[0]], ssa('val', u), dtype, gate=r[src[2]] if len(src) > 3 else None,
+ alt=r[src[3]] if len(src) > 3 else None, ss=mem_type, offset=src[1].arg))
  elif uop is UOps.PHI:
  if dtype.count > 1:
- for x0, x1 in zip(r[vin[0]], r[vin[1]]): kk(f"mov.b{self.types[dtype.scalar()][1:]} {x0}, {x1};")
+ for x0, x1 in zip(r[src[0]], r[src[1]]): kk(f"mov.b{self.types[dtype.scalar()][1:]} {x0}, {x1};")
  else:
- kk(f"mov.b{self.types[dtype][1:]} {r[vin[0]]}, {r[vin[1]]};")
- r[u] = r[vin[0]]
+ kk(f"mov.b{self.types[dtype][1:]} {r[src[0]]}, {r[src[1]]};")
+ r[u] = r[src[0]]
  elif uop in {UOps.CAST, UOps.BITCAST}:
- assert vin[0].dtype is not None
- if dtype.count>1: r[u] = [r[x] for x in vin] # type: ignore
- else: _cast(r[vin[0]], dtype, vin[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
+ assert src[0].dtype is not None
+ if dtype.count>1: r[u] = [r[x] for x in src] # type: ignore
+ else: _cast(r[src[0]], dtype, src[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
  elif uop is UOps.DEFINE_LOCAL:
  # TODO: we should sum these, and fetch 0xC000 from somewhere
  assert args[1]*dtype.itemsize <= 0xC000, "too large local"
@@ -215,62 +205,65 @@ class PTXRenderer(Renderer):
  elif uop is UOps.DEFINE_VAR:
  bufs.append((args.expr, dtype))
  r[u] = f"%{args.expr}"
- if self.load_global: kk(*self.render_load(args.expr, ssa('dat', u, self.types[dtype]), dtype, ss=".param"))
+ kk(*self.render_load(args.expr, ssa('dat', u, self.types[dtype]), dtype, ss=".param"))
  elif uop is UOps.DEFINE_GLOBAL:
  bufs.append((nm:=f"data{args[0]}", dtype))
  r[u] = f"%{nm}"
- if self.load_global:
- dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
- kk(*self.render_load(nm, ssa('dat', u, self.types[dt]), dt, ss=".param"))
+ dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
+ kk(*self.render_load(nm, ssa('dat', u, self.types[dt]), dt, ss=".param"))
  elif uop is UOps.WMMA:
  wmma = []
- for vv in vin[:2]:
+ for vv in src[:2]:
  for i in range(0, len(r[vv]), 2):
  wmma.append(ssa("wmma", dtype="b32"))
  kk(f'mov.b32 {wmma[-1]}, {{{", ".join(r[vv][i:i+2])}}};')
  r[u] = [ssa("wmma", dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
  kk(f'mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32\
- {{{", ".join(r[u])}}}, {{{", ".join(wmma[:4])}}}, {{{", ".join(wmma[4:])}}}, {{{", ".join(r[vin[2]])}}};')
+ {{{", ".join(r[u])}}}, {{{", ".join(wmma[:4])}}}, {{{", ".join(wmma[4:])}}}, {{{", ".join(r[src[2]])}}};')
  else: raise NotImplementedError(f"no code for {uop}")

  return self.render_kernel(kernel, name, bufs, c.items())

  ptx_matcher = PatternMatcher([
- ({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPEQ, "vin": ({"dtype": dtypes.bool},{})},
- lambda root: UOp(UOps.ALU, dtypes.bool, (UOp(root.uop, root.dtype, root.vin, BinaryOps.XOR),), UnaryOps.NEG)),
- ({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})},
- lambda root,x,y: UOp(root.uop, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
- ({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.ADD,
- "vin": [{"__name__": "non_muls"}, {"__name__": "muls", "uop": UOps.ALU, "arg": BinaryOps.MUL}]},
- lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.vin + (non_muls,), TernaryOps.MULACC)),
- *[({"__name__": "x", "uop": UOps.ALU, "dtype": dtypes.half, "arg": op},
- lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.uop, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.vin]), x.arg),)))
+ (UPat(UOps.ALU, BinaryOps.MUL, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
+ src=[UPat(UOps.CONST, set([2**i for i in range(64)]), name="const"), UPat(name="mul")]),
+ lambda root, mul, const: UOp(UOps.ALU, root.dtype, (mul, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHL)),
+ (UPat(UOps.ALU, BinaryOps.IDIV, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
+ src=[UPat(UOps.CONST, set([2**i for i in range(64)]), name="const"), UPat(name="div")]),
+ lambda root, div, const: UOp(UOps.ALU, root.dtype, (div, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHR)),
+ (UPat(UOps.ALU, BinaryOps.CMPNE, (UPat(dtype=dtypes.bool),UPat()), "root"), lambda root: UOp(root.op, root.dtype, root.src, BinaryOps.XOR)),
+ (UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(name="x", dtype=dtypes.bool),UPat(name="y")), "root"),
+ lambda root,x,y: UOp(root.op, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
+ (UPat(UOps.ALU, BinaryOps.ADD,
+ [UPat(name="non_muls"), UPat(UOps.ALU, BinaryOps.MUL, name="muls")], "root"),
+ lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.src + (non_muls,), TernaryOps.MULACC)),
+ *[(UPat(UOps.ALU, op, dtype=dtypes.half, name="x"),
+ lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.op, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.src]), x.arg),)))
  for op in PTXRenderer.asm_for_op.keys() if op not in PTXRenderer.supports_half],
- ({"__name__": "root", "uop": UOps.LOAD, "dtype": dtypes.bool,
- "vin": ({"__name__": "x"},{"__name__": "y"},{"__name__": "z"},{"__name__": "k"})},
- lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)),
- ({"__name__": "root", "uop": UOps.LOAD,"dtype": dtypes.bool, "vin": ({},{})},
- lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.uint8, root.vin, root.arg),))),
- ({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool}, {})},
- lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
- ({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool})},
- lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
- ({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{},{"__name__": "g", "dtype": dtypes.int})},
- lambda root,g: UOp(root.uop, root.dtype, root.vin[:3] + (UOp(UOps.CAST, dtypes.bool, (g,)),), root.arg)),
+ (UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(name="x"),UPat(name="y"),UPat(name="z"),UPat(name="k"))),
+ lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.op, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)),
+ (UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(),UPat())),
+ lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.op, dtypes.uint8, root.src, root.arg),))),
+ (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool), UPat())),
+ lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
+ (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool))),
+ lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
+ (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(),UPat(name="g", dtype=dtypes.int))),
+ lambda root,g: UOp(root.op, root.dtype, root.src[:3] + (UOp(UOps.CAST, dtypes.bool, (g,)),), root.arg)),
  # ptr_ar (load/store)
- ({"__name__": "root", "uop": {UOps.LOAD, UOps.STORE}, "__allow_len__":[2,3,4,5], "vin": ({"uop":{UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}},
- {"uop": UOps.ALU, "arg": BinaryOps.ADD,"vin":[{"__name__": "alu"}, {"__name__": "const", "uop":UOps.CONST}]})},
- lambda root, alu, const: UOp(root.uop, root.dtype,
- (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.vin[0].dtype.itemsize)+root.vin[0].cast(dtypes.int64),
- UOp.const(const.dtype, root.vin[0].dtype.itemsize)*const)+root.vin[2:])),
- ({"__name__": "root", "uop": {UOps.LOAD, UOps.STORE}, "__allow_len__":[2,3,4,5], "vin": ({"uop":{UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}},
- {"__name__": "const", "uop":UOps.CONST})},
- lambda root, const: UOp(root.uop, root.dtype, (root.vin[0].cast(dtypes.int64),
- UOp.const(dtypes.int64, const.arg * root.vin[0].dtype.itemsize),
- )+root.vin[2:])),
- ({"__name__": "root", "uop": {UOps.LOAD, UOps.STORE}, "__allow_len__":[2,3,4,5], "vin": ({"uop":{UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}},
- {"__name__": "alu"})}, # no const here
- lambda root, alu: UOp(root.uop, root.dtype,
- (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.vin[0].dtype.itemsize)+root.vin[0].cast(dtypes.int64),
- UOp.const(dtypes.int64, 0))+root.vin[2:])),
+ (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
+ UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="alu"), UPat(UOps.CONST, name="const")]))),
+ lambda root, alu, const: UOp(root.op, root.dtype,
+ (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
+ UOp.const(const.dtype, root.src[0].dtype.itemsize)*const)+root.src[2:])),
+ (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
+ UPat(UOps.CONST, name="const"))),
+ lambda root, const: UOp(root.op, root.dtype, (root.src[0].cast(dtypes.int64),
+ UOp.const(dtypes.int64, const.arg * root.src[0].dtype.itemsize),
+ )+root.src[2:])),
+ (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
+ UPat(name="alu"))), # no const here
+ lambda root, alu: UOp(root.op, root.dtype,
+ (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
+ UOp.const(dtypes.int64, 0))+root.src[2:])),
  ])
tinygrad/renderer/cstyle.py
@@ -1,11 +1,10 @@
  from typing import Dict, List, Optional, Tuple, Union, DefaultDict, cast, Literal, Callable
  import os, math
  from collections import defaultdict, Counter
- from tinygrad.codegen.linearizer import UOps, UOp
  from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
- from tinygrad.helpers import strip_parens, getenv, prod
+ from tinygrad.helpers import strip_parens, getenv, prod, dedup
  from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
- from tinygrad.codegen.uops import UOpGraph
+ from tinygrad.codegen.uops import UOps, UOp, UOpGraph
  from tinygrad.renderer import Renderer, TensorCore

  class CStyleLanguage(Renderer):
@@ -25,10 +24,11 @@ class CStyleLanguage(Renderer):
  type_map: Dict[DType, str] = {}
  code_for_op: Dict = {
  UnaryOps.NEG: lambda x,dtype: f"(!{x})" if dtype == dtypes.bool else f"(-{x})", UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})",
+ UnaryOps.RECIP: lambda x,dtype: f"(1/{x})",
  UnaryOps.EXP2: lambda x,dtype: f"exp2({x})", UnaryOps.LOG2: lambda x,dtype: f"log2({x})", UnaryOps.SIN: lambda x,dtype: f"sin({x})",
- BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.SUB: lambda a,b,dtype: f"({a}-{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})",
- BinaryOps.DIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",
- BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.CMPEQ: lambda a,b,dtype: f"({a}=={b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})",
+ BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})",
+ BinaryOps.IDIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",
+ BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.CMPNE: lambda a,b,dtype: f"({a}!={b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})",
  TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"}

  # returns a str expression of the casted xs with the given type
@@ -103,31 +103,32 @@ class CStyleLanguage(Renderer):
  c[prefix] += 1
  return ret

- child_count = Counter(v for ru in uops for v in ru.vin)
+ child_count = Counter(v for ru in uops for v in ru.src)

+ seen_vars = set()
  for u in uops:
- uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
+ uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
  # these four uops don't have output dtypes
  if uop is UOps.IF:
- kk(f"if ({r[vin[0]]}) {{")
+ kk(f"if ({r[src[0]]}) {{")
  depth += 1
  elif uop is UOps.BARRIER: kk(self.barrier)
  elif uop in {UOps.ENDRANGE, UOps.ENDIF}:
  depth -= 1
  kk("}")
  elif uop is UOps.STORE:
- assert vin[0].dtype is not None and vin[2].dtype is not None
- rendered_store = self.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL)
- kk(f"if ({r[vin[3]]}) {{ {rendered_store} }}" if len(vin) > 3 else rendered_store)
+ assert src[0].dtype is not None and src[2].dtype is not None
+ rendered_store = self.render_store(r[src[0]], src[0].dtype, r[src[2]], src[2].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
+ kk(f"if ({r[src[3]]}) {{ {rendered_store} }}" if len(src) > 3 else rendered_store)
  else:
  assert dtype is not None, f"None dtype for uop {uop}"
  if uop is UOps.RANGE:
- kk(f"for (int {(expr := ssa('ridx',u))} = {r[vin[0]]}; {expr} < {r[vin[1]]}; {expr}++) {{")
+ kk(f"for (int {(expr := ssa('ridx',u))} = {r[src[0]]}; {expr} < {r[src[1]]}; {expr}++) {{")
  depth += 1
  elif uop is UOps.ALU:
  # remove parens if ALU types are the same. TODO: can do more here
- if args in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == args else r[v]for v in vin]
- else: operands = [r[v] for v in vin]
+ if args in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == args else r[v]for v in src]
+ else: operands = [r[v] for v in src]
  val = self.code_for_op[args](*operands, dtype)
  assert child_count[u] != 0, f"childless ALU op found {u}"
  # TODO: fix index rendering issue. fix clang nested max macro issue
@@ -137,39 +138,41 @@ class CStyleLanguage(Renderer):
  kk(f"int {args[1]} = {self.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")
  r[u] = args[1]
  elif uop is UOps.LOAD:
- val = self.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL)
+ val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
  # NOTE: this relies on the load not happening if it's in the unselected branch
- if len(vin) > 3: val = self.code_for_op[TernaryOps.WHERE](r[vin[2]], val, r[vin[3]], dtype)
+ if len(src) > 3: val = self.code_for_op[TernaryOps.WHERE](r[src[2]], val, r[src[3]], dtype)
  kk(f"{self.render_dtype(dtype)} {ssa('val',u)} = {val};")
  elif uop is UOps.PHI:
- kk(f"{r[vin[0]]} = {r[vin[1]]};")
- r[u] = r[vin[0]]
+ kk(f"{r[src[0]]} = {r[src[1]]};")
+ r[u] = r[src[0]]
  elif uop in {UOps.CAST, UOps.BITCAST}:
  if uop is UOps.BITCAST:
- assert len(vin) == 1
+ assert len(src) == 1
  precast = ssa('precast')
- kk(f"{self.render_dtype(cast(DType, vin[0].dtype))} {precast} = {r[vin[0]]};")
+ kk(f"{self.render_dtype(cast(DType, src[0].dtype))} {precast} = {r[src[0]]};")
  val = self.render_cast([precast], dtype, bitcast=True)
  else:
- val = self.render_cast([r[x] for x in vin], dtype, bitcast=False)
+ val = self.render_cast([r[x] for x in src], dtype, bitcast=False)
  if child_count[u] <= 1: r[u] = val
  else: kk(f"{self.render_dtype(dtype)} {ssa('cast',u)} = {val};")
  elif uop is UOps.DEFINE_LOCAL:
  kk(self.render_local(args[0], dtype, args[1]))
  r[u] = args[0]
  elif uop is UOps.DEFINE_VAR:
+ assert args.expr not in seen_vars, f"duplicate variable {args.expr}"
+ seen_vars.add(args.expr)
  bufs.append((args.expr, (dtype,False)))
  r[u] = args.expr
  elif uop is UOps.DEFINE_GLOBAL:
  bufs.append((nm:=f"data{args[0]}", (dtype,args[1])))
  r[u] = nm
- elif uop is UOps.WMMA: kk(f"{self.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
- elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {self.render_const(args[0], dtype)};")
+ elif uop is UOps.WMMA: kk(f"{self.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[src[0]]}, {r[src[1]]}, {r[src[2]]});")
+ elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {self.render_const(src[0].arg, dtype)};")
  elif uop is UOps.CONST: r[u] = self.render_const(args, dtype) if args >= 0 else f"({self.render_const(args, dtype)})"
  elif uop is UOps.GEP:
- assert vin[0].dtype is not None
- from_ssa = vin[0].uop in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
- r[u] = (r[vin[0]] if from_ssa else f"{(r[vin[0]])}") + (f"[{args}]" if vin[0].dtype.count > 4 else f".{'xyzw'[args]}")
+ assert src[0].dtype is not None
+ from_ssa = src[0].op in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
+ r[u] = (r[src[0]] if from_ssa else f"{(r[src[0]])}") + (f"[{args}]" if src[0].dtype.count > 4 else f".{'xyzw'[args]}")
  else: raise RuntimeError(f"failed to render {uop}")

  return self.render_kernel(name, kernel, bufs, uops)
@@ -178,6 +181,7 @@ class ClangRenderer(CStyleLanguage):
  device = "CLANG"
  supports_float4 = False
  has_local = False
+ global_max = None

  # language options
  buffer_suffix = " restrict"
@@ -219,6 +223,7 @@ class MetalRenderer(CStyleLanguage):
  float4 = "float4"
  uses_ptr_arithmetic = True
  code_for_workitem = {"g": lambda x: f"gid.{chr(120+x)}", "l": lambda x: f"lid.{chr(120+x)}"}
+ # uint3 used for gid/lid - TODO: this should probably be `ushort3 lid [[thread_position_in_threadgroup]]`
  extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']
  type_map = {dtypes.bfloat16: "bfloat"}
  code_for_op = {**CStyleLanguage().code_for_op,
@@ -232,14 +237,15 @@ class MetalRenderer(CStyleLanguage):
  return f"as_type<{self.render_dtype(var_dtype)}>({x[0]})" if bitcast else super().render_cast(x, var_dtype)

  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
- prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.uop is UOps.WMMA])
+ prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.op is UOps.WMMA])
  for arg in wmma_args: prefix.append(f"""{arg[3].name}2 __{arg[0]}({arg[2].name}2 m, {arg[2].name}2 n, {arg[3].name}2 o) {{
  simdgroup_{arg[3].name}8x8 a,b,c; a.thread_elements()[0] = m.x; a.thread_elements()[1] = m.y; b.thread_elements()[0] = n.x;
  b.thread_elements()[1] = n.y; c.thread_elements()[0] = o.x; c.thread_elements()[1] = o.y; simdgroup_multiply_accumulate(c, a, b, c);
  return {arg[3].name}2(c.thread_elements()[0], c.thread_elements()[1]);\n}}""")
  return super().render_kernel(function_name, kernel, bufs, uops, prefix)

- code_for_op_half = {BinaryOps.MAX: lambda a,b,dtype: f"__hmax({a},{b})" if dtype in (dtypes.half, dtypes.bfloat16) else f"max({a},{b})",
+ code_for_op_half = {UnaryOps.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"1/{x}",
+ BinaryOps.MAX: lambda a,b,dtype: f"__hmax({a},{b})" if dtype in (dtypes.half, dtypes.bfloat16) else f"max({a},{b})",
  UnaryOps.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
  UnaryOps.SIN: lambda x,dtype: f"hsin({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sin({x})",
  UnaryOps.LOG2: lambda x,dtype: f"hlog2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"log2({x})",
@@ -252,8 +258,8 @@ def _make_cuda_dtype(base_type, name, cnt):

  class CUDARenderer(CStyleLanguage):
  device = "CUDA"
- global_max = [65535, 65535, 2147483647]
- local_max = [64, 1024, 1024]
+ global_max = (2147483647, 65535, 65535)
+ local_max = (1024, 1024, 64)
  shared_max = 49152
  tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[0],[5],[-2],[0],[-1,1,2,-3],[3,4]], [[3],[4],[0],[0],[5],[-1,1,2,-2],[0]], [[-1],[1],[5],[-2],[2],[0],[3,4]] ], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float), (dtypes.bfloat16, dtypes.float)])] # noqa: E501
  def __init__(self, arch:str): self.tensor_cores = CUDARenderer.tensor_cores if int(arch[3:]) >= 80 else []
@@ -281,7 +287,7 @@ class CUDARenderer(CStyleLanguage):
  prefix += ["#include <cuda_bf16.h>"] + [_make_cuda_dtype("nv_bfloat16", "bfloat16", x) for x in [4, 8]]

  # TODO: this has to be way better to generate for arbitrary M,N,K: use arg[1] for MNK, use arg[4] for vec sizes, encode register packing
- for arg in set([uop.arg for uop in uops if uop.uop is UOps.WMMA]):
+ for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]):
  fn, ti, to, ci, co = arg[0], dt_map[arg[2]][0], dt_map[arg[3]][0], dt_map[arg[2]][1], dt_map[arg[3]][1]
  prefix.append(f"""__device__ {to}4 __{fn}({ti}8 a, {ti}4 b, {to}4 c) {{ int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
  asm( "mma.sync.aligned.m16n8k16.row.col.{co}.{ci}.{ci}.{co} {{ %0, %1, %2, %3 }}, {{ %4, %5, %6, %7 }}, {{ %8, %9 }}, {{ %0, %1, %2, %3 }};"
@@ -312,8 +318,8 @@ def _make_hip_dtype(base_type, name, cnt):
  return f"typedef {base_type} {name}{cnt} __attribute__((ext_vector_type({cnt})));\n" + \
  f"static inline __attribute__((device)) {name}{cnt} make_{name}{cnt}({header}) {{ return {{{elems}}}; }}"

- class HIPRenderer(CStyleLanguage):
- device = "HSA"
+ class AMDRenderer(CStyleLanguage):
+ device = "AMD"
  shared_max = 65536
  tensor_cores = [TensorCore(dims=(16,16,16), threads=[(0,8),(0,2),(1,2)], thread_local_sizes=[[16],[16],[4,2]], thread_local_aliases=[ [[0],[0],[2],[-1],[1]], [[1],[2],[0],[-1],[0]], [[1],[2],[-2],[0],[3,-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]] # noqa: E501

@@ -346,18 +352,18 @@ f""" __attribute__((device)) __attribute__((const)) {dt} __ocml_fmax_f{n}({dt},
  if any(uop.dtype == dtypes.bfloat16 for uop in uops): prefix.append("""
  struct hip_bfloat16 {
  unsigned short data;
- __attribute__((device)) hip_bfloat16(float val) {
+ inline __attribute__((device)) hip_bfloat16(float val) {
  union { float fp32; unsigned int u32; } u = {val};
  if (~u.u32 & 0x7f800000) { u.u32 += 0x7fff + ((u.u32 >> 16) & 1); } else if (u.u32 & 0xffff) { u.u32 |= 0x10000; }
  data = (u.u32 >> 16);
  }
- __attribute__((device)) operator float() const {
+ inline __attribute__((device)) operator float() const {
  unsigned int uval = data << 16;
  return *reinterpret_cast<float*>(&uval);
  }
  };
- static __attribute__((device)) bool operator<(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) < ((float)b); }
- static __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) == ((float)b); }
+ static inline __attribute__((device)) bool operator<(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) < ((float)b); }
+ static inline __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) == ((float)b); }
  """)

  if any(uop.dtype == dtypes.half for uop in uops):
@@ -366,19 +372,18 @@ static __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) {

  prefix += [_make_hip_dtype(*x) for x in vec_dts]

- for arg in set([uop.arg for uop in uops if uop.uop is UOps.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
+ for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
  if arg[3] == dtypes.float: prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_f32_16x16x16_f16_w32")
- else: prefix.append(f"static __attribute__((device)) half8 __{arg[0]}"+"""(half16 a, half16 b, half8 c) {
+ else: prefix.append(f"static inline __attribute__((device)) half8 __{arg[0]}"+"""(half16 a, half16 b, half8 c) {
  half16 c_frag = {}; half8 d; for (int n = 0; n < 8; n++) { c_frag[n*2] = c[n]; }
  c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(a, b, c_frag, false);
  for (int n = 0; n < 8; n++) { d[n] = c_frag[n*2]; } return d;\n}""")
  return super().render_kernel(function_name, kernel, bufs, uops, prefix)

  def get_kernel_modifier(self, uops:UOpGraph) -> str:
- requiredMaxThreadsPerBlock = prod(u.arg[2] for u in uops if u.uop is UOps.SPECIAL and u.arg[1][0] == "l")
+ requiredMaxThreadsPerBlock = prod(u.arg[2] for u in uops if u.op is UOps.SPECIAL and u.arg[1][0] == "l")
  # https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
  # NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
  return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"

  class NVRenderer(CUDARenderer): device = "NV"
- class AMDRenderer(HIPRenderer): device = "AMD"
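
With the HSA runtime removed (tinygrad/runtime/ops_hsa.py and the HSA driver/graph files deleted in the file list above), AMD GPUs are addressed through the AMD backend in 0.9.1. A brief usage sketch, assuming the standard Tensor constructor and the device name shown in this diff:

    from tinygrad import Tensor

    # select the AMD backend explicitly (device="HSA" from 0.9.0 no longer exists)
    x = Tensor([1.0, 2.0, 3.0], device="AMD")
    print((x * 2).numpy())
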