tinygrad-0.9.1-py3-none-any.whl → tinygrad-0.9.2-py3-none-any.whl

This diff compares the contents of two publicly released package versions from one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
Files changed (56)
  1. tinygrad/codegen/kernel.py +248 -115
  2. tinygrad/codegen/lowerer.py +215 -0
  3. tinygrad/codegen/transcendental.py +310 -0
  4. tinygrad/codegen/uopgraph.py +622 -0
  5. tinygrad/codegen/uops.py +235 -393
  6. tinygrad/device.py +428 -69
  7. tinygrad/dtype.py +18 -4
  8. tinygrad/engine/graph.py +19 -32
  9. tinygrad/engine/jit.py +148 -70
  10. tinygrad/engine/realize.py +127 -51
  11. tinygrad/engine/schedule.py +259 -216
  12. tinygrad/engine/search.py +29 -22
  13. tinygrad/function.py +9 -0
  14. tinygrad/helpers.py +87 -49
  15. tinygrad/lazy.py +34 -35
  16. tinygrad/multi.py +41 -36
  17. tinygrad/nn/__init__.py +39 -22
  18. tinygrad/nn/state.py +3 -3
  19. tinygrad/ops.py +63 -62
  20. tinygrad/renderer/__init__.py +43 -21
  21. tinygrad/renderer/assembly.py +104 -106
  22. tinygrad/renderer/cstyle.py +87 -60
  23. tinygrad/renderer/llvmir.py +21 -30
  24. tinygrad/runtime/autogen/amd_gpu.py +25208 -5753
  25. tinygrad/runtime/autogen/cuda.py +6 -162
  26. tinygrad/runtime/autogen/kfd.py +32 -0
  27. tinygrad/runtime/autogen/libc.py +4260 -0
  28. tinygrad/runtime/autogen/nvrtc.py +579 -0
  29. tinygrad/runtime/graph/clang.py +2 -2
  30. tinygrad/runtime/graph/cuda.py +8 -11
  31. tinygrad/runtime/graph/hcq.py +120 -107
  32. tinygrad/runtime/graph/metal.py +18 -15
  33. tinygrad/runtime/ops_amd.py +197 -305
  34. tinygrad/runtime/ops_clang.py +2 -2
  35. tinygrad/runtime/ops_cuda.py +36 -94
  36. tinygrad/runtime/ops_disk.py +3 -7
  37. tinygrad/runtime/ops_gpu.py +4 -2
  38. tinygrad/runtime/ops_hip.py +70 -0
  39. tinygrad/runtime/ops_metal.py +38 -27
  40. tinygrad/runtime/ops_nv.py +283 -363
  41. tinygrad/runtime/ops_python.py +26 -30
  42. tinygrad/runtime/support/compiler_cuda.py +78 -0
  43. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +15 -1
  44. tinygrad/runtime/support/elf.py +38 -0
  45. tinygrad/shape/shapetracker.py +5 -14
  46. tinygrad/shape/symbolic.py +4 -8
  47. tinygrad/shape/view.py +34 -22
  48. tinygrad/tensor.py +399 -97
  49. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/METADATA +49 -48
  50. tinygrad-0.9.2.dist-info/RECORD +70 -0
  51. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/WHEEL +1 -1
  52. tinygrad/codegen/linearizer.py +0 -528
  53. tinygrad-0.9.1.dist-info/RECORD +0 -63
  54. /tinygrad/runtime/{driver → support}/__init__.py +0 -0
  55. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/LICENSE +0 -0
  56. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/top_level.txt +0 -0
tinygrad/renderer/cstyle.py
@@ -4,7 +4,7 @@ from collections import defaultdict, Counter
  from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
  from tinygrad.helpers import strip_parens, getenv, prod, dedup
  from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
- from tinygrad.codegen.uops import UOps, UOp, UOpGraph
+ from tinygrad.codegen.uops import UOps, UOp
  from tinygrad.renderer import Renderer, TensorCore

  class CStyleLanguage(Renderer):
@@ -22,6 +22,8 @@ class CStyleLanguage(Renderer):
  uses_vload: bool = False
  uses_ptr_arithmetic: bool = False
  type_map: Dict[DType, str] = {}
+ infinity: str = "INFINITY"
+ nan: str = "NAN"
  code_for_op: Dict = {
  UnaryOps.NEG: lambda x,dtype: f"(!{x})" if dtype == dtypes.bool else f"(-{x})", UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})",
  UnaryOps.RECIP: lambda x,dtype: f"(1/{x})",
@@ -29,24 +31,30 @@ class CStyleLanguage(Renderer):
  BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})",
  BinaryOps.IDIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",
  BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.CMPNE: lambda a,b,dtype: f"({a}!={b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})",
+ BinaryOps.AND: lambda a,b,dtype: f"({a}&{b})", BinaryOps.OR: lambda a,b,dtype: f"({a}|{b})",
  TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"}

  # returns a str expression of the casted xs with the given type
- def render_cast(self, x:List[str], var_dtype:DType, bitcast=False) -> str:
- if bitcast: return f"(*(({self.buffer_prefix}{self.render_dtype(var_dtype)}*)&{x[0]}))"
- if len(x) == 1: return f"({self.render_dtype(var_dtype)})({x[0]})"
+ def render_cast(self, x:str, var_dtype:DType, bitcast=False) -> str:
+ if bitcast: return f"(*(({self.buffer_prefix}{self.render_dtype(var_dtype)}*)&{x}))"
+ return f"({self.render_dtype(var_dtype)})({x})"
+
+ # returns a str expression of the vectorized xs with the given type
+ def render_vectorize(self, x:List[str], var_dtype:DType) -> str:
  assert len(x) == var_dtype.count, f"cast is wrong size {len(x)} != {var_dtype.count}"
  assert self.float4 is not None, "vectorized cast is not supported on this platform"
- return f"{self.float4.replace('float4', self.render_dtype(var_dtype))}({','.join(x)})"
+ return f"{self.float4.replace('float4', self.render_dtype(var_dtype))}" + (f"{{{','.join(x)}}}" if self.device == "CLANG" else f"({','.join(x)})")

  # returns a str expression of the const with the given type
  def render_const(self, x:ConstType, dtype:DType) -> str:
- if math.isnan(x): val = "NAN"
- elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
+ assert dtype.count == 1, f"consts should be scalar, got {dtype}"
+ if math.isnan(x): val = self.nan
+ elif math.isinf(x): val = ("-" if x < 0 else "") + self.infinity
  elif dtype == dtypes.bool: val = "1" if x else "0"
  elif dtype == dtypes.float: val = f"{x}f"
+ elif dtype == dtypes.uint64: val = f"{x}ULL"
  else: val = str(x)
- return (self.render_cast([val] * dtype.count, dtype) if dtype.count > 1 or dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val)
+ return (self.render_cast(val, dtype) if dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val)

  # returns a str expression of the loaded value with the output type
  def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
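A minimal sketch of the reworked scalar helpers above (not part of the diff; it assumes tinygrad 0.9.2 is installed and that CStyleLanguage can be instantiated directly): render_cast now takes a single expression string, and render_const routes infinities and NaNs through the new self.infinity / self.nan attributes.

from tinygrad.dtype import dtypes
from tinygrad.renderer.cstyle import CStyleLanguage

r = CStyleLanguage()
print(r.render_cast("x", dtypes.int))              # single expression now, e.g. "(int)(x)"
print(r.render_const(float("inf"), dtypes.float))  # uses the new class attribute: "INFINITY"
print(r.render_const(3, dtypes.uint64))            # new uint64 path, e.g. "(unsigned long)(3ULL)"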
@@ -56,13 +64,11 @@ class CStyleLanguage(Renderer):
  if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and output_dtype.scalar() != dtypes.float16:
  return f"vload_half{'' if output_dtype.count == 1 else str(output_dtype.count)}(0, {buf_name}+{idx})"
  if output_dtype.count > 1:
- out_val = f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{self.render_dtype(buf_dtype)}{output_dtype.count}*)({buf_name}+{idx}))" # noqa: E501
- else:
- out_val = f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]"
- return self.render_cast([out_val], output_dtype) if output_dtype != buf_dtype else out_val
+ return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{self.render_dtype(output_dtype)}*)({buf_name}+{idx}))" # noqa: E501
+ return f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]"

- def get_kernel_modifier(self, uops:UOpGraph) -> str: return ""
- def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:UOpGraph, prefix=None) -> str:
+ def get_kernel_modifier(self, uops:List[UOp]) -> str: return ""
+ def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str:
  tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else "" # noqa: E501
  buftypes = [(name,f"{'write_only' if mutable else 'read_only'} image2d_t" if dtype.name.startswith('image') else
  ("" if mutable else "const ")+self.buffer_prefix+self.render_dtype(dtype)+"*"+self.buffer_suffix if isinstance(dtype, PtrDType) else
@@ -81,15 +87,16 @@ class CStyleLanguage(Renderer):
  return f"vstore_half{'' if var_dtype.count == 1 else str(var_dtype.count)}({var_name}, 0, {buf_name}+{idx});"
  if var_dtype.count > 1:
  prefix = self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix
- return f"*(({prefix}{self.render_dtype(buf_dtype)}{var_dtype.count}*)({buf_name}+{idx})) = {var_name};"
+ return f"*(({prefix}{self.render_dtype(var_dtype)}*)({buf_name}+{idx})) = {var_name};"
  return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"

  def render_local(self, name:str, dtype:DType, size:int): return self.smem_align + self.smem_prefix + f"{self.render_dtype(dtype)} {name}[{size}];"
- def render_dtype(self, var_dtype:DType) -> str: return self.type_map.get(var_dtype, var_dtype.name)
+ def render_dtype(self, var_dtype:DType) -> str:
+ return self.type_map.get(scalar:=var_dtype.scalar(), scalar.name) + (str(var_dtype.count) if (var_dtype.count) > 1 else "")

- def render(self, name:str, uops:UOpGraph) -> str:
+ def render(self, name:str, uops:List[UOp]) -> str:
  kernel = []
- bufs: List[Tuple[str, Tuple[DType, bool]]] = []
+ bufs: Dict[UOp, Tuple[str, Tuple[DType, bool]]] = {}
  depth = 1
  def kk(s): kernel.append(" "*depth+s)

@@ -118,6 +125,8 @@ class CStyleLanguage(Renderer):
  kk("}")
  elif uop is UOps.STORE:
  assert src[0].dtype is not None and src[2].dtype is not None
+ # mark DEFINE_GLOBAL buf as writable
+ if src[0].op is UOps.DEFINE_GLOBAL: bufs[src[0]] = (bufs[src[0]][0], (bufs[src[0]][1][0], True))
  rendered_store = self.render_store(r[src[0]], src[0].dtype, r[src[2]], src[2].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
  kk(f"if ({r[src[3]]}) {{ {rendered_store} }}" if len(src) > 3 else rendered_store)
  else:
@@ -128,6 +137,7 @@ class CStyleLanguage(Renderer):
  elif uop is UOps.ALU:
  # remove parens if ALU types are the same. TODO: can do more here
  if args in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == args else r[v]for v in src]
+ elif args is BinaryOps.MAX: operands = [self.render_cast(r[v], cast(DType, v.dtype)) if v.op is UOps.CONST else r[v] for v in src]
  else: operands = [r[v] for v in src]
  val = self.code_for_op[args](*operands, dtype)
  assert child_count[u] != 0, f"childless ALU op found {u}"
@@ -135,58 +145,71 @@ class CStyleLanguage(Renderer):
  if child_count[u] <= 1 and args is not BinaryOps.MAX and not getenv("EXPAND_SSA"): r[u] = val
  else: kk(f"{self.render_dtype(dtype)} {ssa('alu',u)} = {val};")
  elif uop is UOps.SPECIAL:
- kk(f"int {args[1]} = {self.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")
- r[u] = args[1]
+ kk(f"int {args[0]} = {self.code_for_workitem[args[0][0]](args[0][-1])}; /* {args[1]} */")
+ r[u] = args[0]
+ elif uop is UOps.DEFINE_VAR:
+ assert args.expr not in seen_vars, f"duplicate variable {args.expr}"
+ seen_vars.add(args.expr)
+ bufs[u] = (args.expr, (dtype,False))
+ r[u] = args.expr
  elif uop is UOps.LOAD:
  val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
  # NOTE: this relies on the load not happening if it's in the unselected branch
- if len(src) > 3: val = self.code_for_op[TernaryOps.WHERE](r[src[2]], val, r[src[3]], dtype)
+ if len(src) > 3 and src[3].op is UOps.ALU: val = self.code_for_op[TernaryOps.WHERE](r[src[3]], val, r[src[2]], dtype)
  kk(f"{self.render_dtype(dtype)} {ssa('val',u)} = {val};")
  elif uop is UOps.PHI:
  kk(f"{r[src[0]]} = {r[src[1]]};")
  r[u] = r[src[0]]
- elif uop in {UOps.CAST, UOps.BITCAST}:
+ elif uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}:
+ assert len(src) == 1 or (uop is UOps.VECTORIZE and len(src) > 1), "Invalid source length for operation"
  if uop is UOps.BITCAST:
- assert len(src) == 1
  precast = ssa('precast')
  kk(f"{self.render_dtype(cast(DType, src[0].dtype))} {precast} = {r[src[0]]};")
- val = self.render_cast([precast], dtype, bitcast=True)
- else:
- val = self.render_cast([r[x] for x in src], dtype, bitcast=False)
+ val = self.render_cast(precast, dtype, bitcast=True)
+ elif uop is UOps.CAST: val = self.render_cast(r[src[0]], dtype, bitcast=False)
+ else: val = self.render_vectorize([r[x] for x in src], dtype)
  if child_count[u] <= 1: r[u] = val
  else: kk(f"{self.render_dtype(dtype)} {ssa('cast',u)} = {val};")
  elif uop is UOps.DEFINE_LOCAL:
  kk(self.render_local(args[0], dtype, args[1]))
  r[u] = args[0]
- elif uop is UOps.DEFINE_VAR:
- assert args.expr not in seen_vars, f"duplicate variable {args.expr}"
- seen_vars.add(args.expr)
- bufs.append((args.expr, (dtype,False)))
- r[u] = args.expr
  elif uop is UOps.DEFINE_GLOBAL:
- bufs.append((nm:=f"data{args[0]}", (dtype,args[1])))
+ bufs[u] = (nm:=f"data{args}", (dtype, False))
  r[u] = nm
  elif uop is UOps.WMMA: kk(f"{self.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[src[0]]}, {r[src[1]]}, {r[src[2]]});")
- elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {self.render_const(src[0].arg, dtype)};")
+ elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {r[src[0]]};")
  elif uop is UOps.CONST: r[u] = self.render_const(args, dtype) if args >= 0 else f"({self.render_const(args, dtype)})"
  elif uop is UOps.GEP:
  assert src[0].dtype is not None
  from_ssa = src[0].op in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
- r[u] = (r[src[0]] if from_ssa else f"{(r[src[0]])}") + (f"[{args}]" if src[0].dtype.count > 4 else f".{'xyzw'[args]}")
- else: raise RuntimeError(f"failed to render {uop}")
+ r[u] = (r[src[0]] if from_ssa else f"{(r[src[0]])}") + \
+ (f"[{args}]" if src[0].dtype.count > (8 if self.device in {"CUDA", "NV"} else 4) or self.device == 'CLANG' else f".{'xyzwabcd'[args]}")
+ else: raise RuntimeError(f"failed to render {u}")
+
+ # NOTE: this relies on bufs dict preserving order
+ return self.render_kernel(name, kernel, list(bufs.values()), uops)

- return self.render_kernel(name, kernel, bufs, uops)
+ def _make_clang_dtype(self, dtype):
+ return f"typedef {self.render_dtype(dtype.scalar())} {self.render_dtype(dtype)} __attribute__((aligned({(sz:=dtype.itemsize)}),vector_size({sz})));"

  class ClangRenderer(CStyleLanguage):
  device = "CLANG"
- supports_float4 = False
+ float4 = "(float4)"
  has_local = False
  global_max = None
+ infinity = "__builtin_inff()"
+ nan = '__builtin_nanf("")'

  # language options
  buffer_suffix = " restrict"
  type_map = {dtypes.bool:"_Bool", dtypes.half:"__fp16"}
- code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.MAX: lambda a,b,dtype: f"(({a}>{b})?{a}:{b})"}
+ code_for_op = {**({k:v for k,v in CStyleLanguage().code_for_op.items() if k not in [UnaryOps.EXP2, UnaryOps.SIN, UnaryOps.LOG2]}),
+ UnaryOps.SQRT: lambda x,dtype: f"__builtin_sqrtl({x})" if dtype == dtypes.float64 else f"__builtin_sqrtf({x})",
+ BinaryOps.MAX: lambda a,b,dtype: f"(({a}>{b})?{a}:{b})"}
+
+ def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
+ prefix = [_make_clang_dtype(self, dtype) for dtype in dedup(uop.dtype for uop in uops if uop.dtype is not None and uop.dtype.count>1)]
+ return super().render_kernel(function_name, kernel, bufs, uops, prefix)

  class OpenCLRenderer(CStyleLanguage):
  device = "GPU"
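For context, a sketch of what the new CLANG vector support above produces; it assumes tinygrad 0.9.2 is importable and that the typedef helper _make_clang_dtype stays a module-level function in tinygrad.renderer.cstyle (treat that import path as an assumption):

from tinygrad.dtype import dtypes
from tinygrad.renderer.cstyle import ClangRenderer, _make_clang_dtype

c = ClangRenderer()
print(c.render_dtype(dtypes.float.vec(4)))                          # render_dtype now appends the count: "float4"
print(c.render_vectorize(["a","b","c","d"], dtypes.float.vec(4)))   # CLANG uses a brace literal: "(float4){a,b,c,d}"
print(_make_clang_dtype(c, dtypes.float.vec(4)))
# typedef float float4 __attribute__((aligned(16),vector_size(16)));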
@@ -202,7 +225,7 @@ class OpenCLRenderer(CStyleLanguage):
  uses_vload = True
  type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong" }
  def render_cast(self, x, var_dtype, bitcast=False) -> str:
- return f"as_{self.render_dtype(var_dtype)}({x[0]})" if bitcast else super().render_cast(x, var_dtype)
+ return f"as_{self.render_dtype(var_dtype)}({x})" if bitcast else super().render_cast(x, var_dtype)

  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
  if any(uop.dtype == dtypes.half for uop in uops): prefix = ["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"]
@@ -211,7 +234,7 @@
  class MetalRenderer(CStyleLanguage):
  device = "METAL"
  shared_max = 32768
- tensor_cores = [TensorCore(dims=(8,8,8), threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[0],[2],[0],[4],[-1, 1, 3],[0]], [[1],[0],[3],[0],[2, 4],[-1]], [[1],[2],[3],[4],[0],[-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.float, dtypes.float), (dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]] # noqa: E501
+ tensor_cores = [TensorCore(dims=(8,8,8), threads=[(0,2),(1,4),(0,2),(1,2)], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.float, dtypes.float), (dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]] # noqa: E501
  def __init__(self): self.tensor_cores = MetalRenderer.tensor_cores if os.uname().machine == "arm64" else []

  # language options
@@ -222,7 +245,7 @@ class MetalRenderer(CStyleLanguage):
  barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);"
  float4 = "float4"
  uses_ptr_arithmetic = True
- code_for_workitem = {"g": lambda x: f"gid.{chr(120+x)}", "l": lambda x: f"lid.{chr(120+x)}"}
+ code_for_workitem = {"g": lambda x: f"gid.{chr(120+int(x))}", "l": lambda x: f"lid.{chr(120+int(x))}"}
  # uint3 used for gid/lid - TODO: this should probably be `ushort3 lid [[thread_position_in_threadgroup]]`
  extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']
  type_map = {dtypes.bfloat16: "bfloat"}
@@ -231,10 +254,10 @@ class MetalRenderer(CStyleLanguage):
  UnaryOps.SQRT: lambda x,dtype: f"(bfloat)sqrt({x})" if dtype == dtypes.bfloat16 else f"sqrt({x})",
  UnaryOps.EXP2: lambda x,dtype: f"(bfloat)exp2({x})" if dtype == dtypes.bfloat16 else f"exp2({x})",
  UnaryOps.LOG2: lambda x,dtype: f"(bfloat)log2({x})" if dtype == dtypes.bfloat16 else f"log2({x})",
- UnaryOps.SIN: lambda x,dtype: f"(bfloat)sin({x})" if dtype == dtypes.bfloat16 else f"sin({x})",}
+ UnaryOps.SIN: lambda x,dtype: f"(bfloat)precise::sin({x})" if dtype == dtypes.bfloat16 else f"precise::sin({x})",}

- def render_cast(self, x: List[str], var_dtype: DType, bitcast=False) -> str:
- return f"as_type<{self.render_dtype(var_dtype)}>({x[0]})" if bitcast else super().render_cast(x, var_dtype)
+ def render_cast(self, x:str, var_dtype:DType, bitcast=False) -> str:
+ return f"as_type<{self.render_dtype(var_dtype)}>({x})" if bitcast else super().render_cast(x, var_dtype)

  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
  prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.op is UOps.WMMA])
@@ -252,16 +275,17 @@ code_for_op_half = {UnaryOps.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dt
  UnaryOps.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})",}

  _nms = "xyzwabcdefghijkl"
- def _make_cuda_dtype(base_type, name, cnt):
- vec, elems, header = f"{name}{cnt}", ', '.join(_nms[:cnt]), ', '.join([f"{base_type} {x}" for x in _nms[:cnt]])
- return f"struct {vec} {{ {base_type} {elems}; }}; __device__ {vec} make_{vec}({header}) {{ {vec} r={{{elems}}}; return r; }}"
+ def _make_cuda_dtype(renderer:CStyleLanguage, dtype:DType) -> str:
+ vec, scal = renderer.render_dtype(dtype), renderer.render_dtype(dtype.scalar()),
+ elems, header = ', '.join(_nms[:dtype.count]), ', '.join([f"{scal} {x}" for x in _nms[:dtype.count]])
+ return f"struct __align__({dtype.itemsize}) {vec} {{ {scal} {elems}; }}; __device__ {vec} make_{vec}({header}) {{ {vec} r={{{elems}}}; return r; }}"

  class CUDARenderer(CStyleLanguage):
  device = "CUDA"
  global_max = (2147483647, 65535, 65535)
  local_max = (1024, 1024, 64)
  shared_max = 49152
- tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[0],[5],[-2],[0],[-1,1,2,-3],[3,4]], [[3],[4],[0],[0],[5],[-1,1,2,-2],[0]], [[-1],[1],[5],[-2],[2],[0],[3,4]] ], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float), (dtypes.bfloat16, dtypes.float)])] # noqa: E501
+ tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(1,2)], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float), (dtypes.bfloat16, dtypes.float)])] # noqa: E501
  def __init__(self, arch:str): self.tensor_cores = CUDARenderer.tensor_cores if int(arch[3:]) >= 80 else []

  # language options
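For illustration, a sketch of the struct definition the reworked _make_cuda_dtype emits for a half4 vector; it assumes the helper remains module-level in tinygrad.renderer.cstyle and that CUDARenderer takes an arch string such as "sm_80":

from tinygrad.dtype import dtypes
from tinygrad.renderer.cstyle import CUDARenderer, _make_cuda_dtype

print(_make_cuda_dtype(CUDARenderer("sm_80"), dtypes.half.vec(4)))
# roughly: struct __align__(8) half4 { half x, y, z, w; };
#          __device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; }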
@@ -270,25 +294,22 @@ class CUDARenderer(CStyleLanguage):
  smem_prefix_for_cast = False
  barrier = "__syncthreads();"
  float4 = "make_float4"
- code_for_workitem = {"g": lambda x: f"blockIdx.{chr(120+x)}", "l": lambda x: f"threadIdx.{chr(120+x)}",
- "i": lambda x: f"(blockIdx.{chr(120+x)}*blockDim.{chr(120+x)}+threadIdx.{chr(120+x)})"}
+ code_for_workitem = {"g": lambda x: f"blockIdx.{chr(120+int(x))}", "l": lambda x: f"threadIdx.{chr(120+int(x))}",
+ "i": lambda x: f"(blockIdx.{chr(120+int(x))}*blockDim.{chr(120+x)}+threadIdx.{chr(120+int(x))})"}
  code_for_op = {**CStyleLanguage().code_for_op, **code_for_op_half}
  type_map = {dtypes.bfloat16: "nv_bfloat16"}

  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
  # TODO: why is dtypes.bfloat16.name == "__bf16"? would be easier not override dtypes.name
- dt_map = { dtypes.float: ("float","f32"), dtypes.half: ("half","f16"), dtypes.bfloat16: ("bfloat16","bf16"), }
-
  prefix = ["#define INFINITY (__int_as_float(0x7f800000))","#define NAN (__int_as_float(0x7fffffff))"]
- if any(uop.dtype == dtypes.half for uop in uops):
- prefix += ["#include <cuda_fp16.h>"] + [_make_cuda_dtype("half", "half", x) for x in [4, 8]]

- if any(uop.dtype == dtypes.bfloat16 for uop in uops):
- prefix += ["#include <cuda_bf16.h>"] + [_make_cuda_dtype("nv_bfloat16", "bfloat16", x) for x in [4, 8]]
+ for dtype in dedup(uop.dtype for uop in uops if uop.dtype is not None and uop.dtype in (dtypes.half, dtypes.bfloat16)):
+ prefix += [f"#include <cuda_{'fp' if dtype == dtypes.half else 'bf'}16.h>"] + [_make_cuda_dtype(self, dtype.vec(sz)) for sz in [4, 8]]

  # TODO: this has to be way better to generate for arbitrary M,N,K: use arg[1] for MNK, use arg[4] for vec sizes, encode register packing
+ dt_map = { dtypes.float: "f32", dtypes.half: "f16", dtypes.bfloat16: "bf16" }
  for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]):
- fn, ti, to, ci, co = arg[0], dt_map[arg[2]][0], dt_map[arg[3]][0], dt_map[arg[2]][1], dt_map[arg[3]][1]
+ fn, ti, to, ci, co = arg[0], self.render_dtype(arg[2]), self.render_dtype(arg[3]), dt_map[arg[2]], dt_map[arg[3]]
  prefix.append(f"""__device__ {to}4 __{fn}({ti}8 a, {ti}4 b, {to}4 c) {{ int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
  asm( "mma.sync.aligned.m16n8k16.row.col.{co}.{ci}.{ci}.{co} {{ %0, %1, %2, %3 }}, {{ %4, %5, %6, %7 }}, {{ %8, %9 }}, {{ %0, %1, %2, %3 }};"
  : "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
@@ -296,6 +317,11 @@ return c;}}""")

  return super().render_kernel(function_name, kernel, bufs, uops, prefix=prefix)

+ def get_kernel_modifier(self, uops:List[UOp]) -> str:
+ maxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is UOps.SPECIAL and u.arg[0][0] == "l")
+ # https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html
+ return f"__launch_bounds__({maxThreadsPerBlock}) "
+
  code_for_op_hip = { UnaryOps.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
  UnaryOps.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
  UnaryOps.LOG2: lambda x,dtype: f"__ocml_log2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
@@ -321,7 +347,7 @@ def _make_hip_dtype(base_type, name, cnt):
  class AMDRenderer(CStyleLanguage):
  device = "AMD"
  shared_max = 65536
- tensor_cores = [TensorCore(dims=(16,16,16), threads=[(0,8),(0,2),(1,2)], thread_local_sizes=[[16],[16],[4,2]], thread_local_aliases=[ [[0],[0],[2],[-1],[1]], [[1],[2],[0],[-1],[0]], [[1],[2],[-2],[0],[3,-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]] # noqa: E501
+ tensor_cores = [TensorCore(dims=(16,16,16), threads=[(0,8),(0,2),(1,2)], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]] # noqa: E501

  # language options
  kernel_prefix = """extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_id(unsigned int);
@@ -380,10 +406,11 @@ static inline __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat
  for (int n = 0; n < 8; n++) { d[n] = c_frag[n*2]; } return d;\n}""")
  return super().render_kernel(function_name, kernel, bufs, uops, prefix)

- def get_kernel_modifier(self, uops:UOpGraph) -> str:
- requiredMaxThreadsPerBlock = prod(u.arg[2] for u in uops if u.op is UOps.SPECIAL and u.arg[1][0] == "l")
+ def get_kernel_modifier(self, uops:List[UOp]) -> str:
+ requiredMaxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is UOps.SPECIAL and u.arg[0][0] == "l")
  # https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
  # NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
  return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"

  class NVRenderer(CUDARenderer): device = "NV"
+ class HIPRenderer(AMDRenderer): device = "HIP"
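Both get_kernel_modifier overrides now read the local launch dims from the SPECIAL uop's new (name, size) arg layout. A trivial sketch of the strings they prepend to a kernel, using hypothetical local sizes instead of real uops:

from tinygrad.helpers import prod

local_sizes = [16, 16]  # hypothetical sizes of the "l"-prefixed SPECIAL uops
print(f"__launch_bounds__({prod(local_sizes)}) ")                               # CUDA/NV modifier
print(f"__attribute__((amdgpu_flat_work_group_size(1, {prod(local_sizes)})))")  # AMD/HIP modifier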
tinygrad/renderer/llvmir.py
@@ -1,32 +1,14 @@
- from typing import Final, Dict, Callable, Any, List, Optional
+ from typing import Dict, Callable, Any, List, Optional
  from llvmlite import ir
  from tinygrad.dtype import DType, PtrDType, dtypes
  from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
- from tinygrad.codegen.uops import UOps, UOp, UOpGraph
+ from tinygrad.codegen.uops import UOps, UOp
  from tinygrad.renderer import Renderer

  MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf

  def is_bool_or_unsigned(dtype: DType): return dtype == dtypes.bool or dtypes.is_unsigned(dtype)

- code_for_op: Final[Dict[Op, Callable]] = {
- UnaryOps.NEG: lambda builder, x, dtype: builder.neg(x) if dtypes.is_int(dtype) else \
- (builder.not_(x) if dtype == dtypes.bool else builder.fneg(x, flags=MFLAGS)),
- UnaryOps.EXP2: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.exp2', [x.type]), [x], fastmath=MFLAGS),
- UnaryOps.LOG2: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.log2', [x.type]), [x], fastmath=MFLAGS),
- UnaryOps.RECIP: lambda builder, x, dtype: builder.fdiv(const(1, dtype), x, flags=MFLAGS),
- UnaryOps.SIN: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.sin', [x.type]), [x], fastmath=MFLAGS),
- UnaryOps.SQRT: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.sqrt', [x.type]), [x], fastmath=MFLAGS),
- BinaryOps.ADD: lambda builder, x, y, dtype: builder.or_(x, y) if dtype == dtypes.bool else builder.add(x, y) if dtypes.is_int(dtype) else builder.fadd(x, y, flags=MFLAGS), # noqa: E501
- BinaryOps.MUL: lambda builder, x, y, dtype: builder.mul(x, y) if is_bool_or_unsigned(dtype) or dtypes.is_int(dtype) else builder.fmul(x, y, flags=MFLAGS), # noqa: E501
- BinaryOps.IDIV: lambda builder, x, y, dtype: builder.udiv(x, y) if is_bool_or_unsigned(dtype) else builder.sdiv(x, y),
- BinaryOps.CMPLT: lambda builder, x, y, dtype: builder.icmp_unsigned("<", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("<", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("<", x, y, flags=MFLAGS), # noqa: E501
- BinaryOps.CMPNE: lambda builder, x, y, dtype: builder.icmp_unsigned("!=", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("!=", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("!=", x, y, flags=MFLAGS), # noqa: E501
- BinaryOps.MAX: lambda builder, x, y, dtype: builder.select(builder.icmp_unsigned(">", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed(">", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered(">", x, y, flags=MFLAGS), x, y), # noqa: E501
- BinaryOps.MOD: lambda builder, x, y, dtype: builder.urem(x, y) if is_bool_or_unsigned(dtype) else builder.srem(x, y) if dtypes.is_int(dtype) else builder.frem(x, y), # noqa: E501
- BinaryOps.XOR: lambda builder, x, y, dtype: builder.xor(x, y),
- TernaryOps.WHERE: lambda builder, x, y, z, dtype: builder.select(x, y, z)}
-
  dtype_to_llvm_dtype = { dtypes.bool:ir.IntType(1), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.int16:ir.IntType(16),
  dtypes.uint16:ir.IntType(16), dtypes.int32:ir.IntType(32), dtypes.uint32:ir.IntType(32), dtypes.int64:ir.IntType(64), dtypes.uint64:ir.IntType(64),
  dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.float64:ir.DoubleType() }
@@ -71,8 +53,22 @@ class LLVMRenderer(Renderer):
  has_local = False
  has_shared = False
  global_max = None
-
- def render(self, name:str, uops:UOpGraph) -> str:
+ code_for_op: Dict[Op, Callable] = {
+ UnaryOps.NEG: lambda builder, x, dtype: builder.neg(x) if dtypes.is_int(dtype) else \
+ (builder.not_(x) if dtype == dtypes.bool else builder.fneg(x, flags=MFLAGS)),
+ UnaryOps.RECIP: lambda builder, x, dtype: builder.fdiv(const(1, dtype), x, flags=MFLAGS),
+ UnaryOps.SQRT: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.sqrt', [x.type]), [x], fastmath=MFLAGS),
+ BinaryOps.ADD: lambda builder, x, y, dtype: builder.or_(x, y) if dtype == dtypes.bool else builder.add(x, y) if dtypes.is_int(dtype) else builder.fadd(x, y, flags=MFLAGS), # noqa: E501
+ BinaryOps.MUL: lambda builder, x, y, dtype: builder.mul(x, y) if is_bool_or_unsigned(dtype) or dtypes.is_int(dtype) else builder.fmul(x, y, flags=MFLAGS), # noqa: E501
+ BinaryOps.IDIV: lambda builder, x, y, dtype: builder.udiv(x, y) if is_bool_or_unsigned(dtype) else builder.sdiv(x, y),
+ BinaryOps.CMPLT: lambda builder, x, y, dtype: builder.icmp_unsigned("<", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("<", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("<", x, y, flags=MFLAGS), # noqa: E501
+ BinaryOps.CMPNE: lambda builder, x, y, dtype: builder.icmp_unsigned("!=", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("!=", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("!=", x, y, flags=MFLAGS), # noqa: E501
+ BinaryOps.MAX: lambda builder, x, y, dtype: builder.select(builder.icmp_unsigned(">", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed(">", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered(">", x, y, flags=MFLAGS), x, y), # noqa: E501
+ BinaryOps.MOD: lambda builder, x, y, dtype: builder.urem(x, y) if is_bool_or_unsigned(dtype) else builder.srem(x, y) if dtypes.is_int(dtype) else builder.frem(x, y), # noqa: E501
+ BinaryOps.XOR: lambda builder, x, y, dtype: builder.xor(x, y), BinaryOps.AND: lambda builder, x, y, dtype: builder.and_(x, y), BinaryOps.OR: lambda builder, x, y, dtype: builder.or_(x, y), # noqa: E501
+ TernaryOps.WHERE: lambda builder, x, y, z, dtype: builder.select(x, y, z)}
+
+ def render(self, name:str, uops:List[UOp]) -> str:
  # all llvm stuff goes into a module
  module = ir.Module(name=__file__)

@@ -86,10 +82,6 @@ class LLVMRenderer(Renderer):
  for a in func.args:
  if a.type.is_pointer: a.add_attribute("noalias")

- # add the function attribute "no-nans-fp-math"="true", which informs llvm that it allowed to use vectorization optimizations
- func.attributes._known = func.attributes._known.union(frozenset(['"no-nans-fp-math"="true"']))
- func.attributes.add('"no-nans-fp-math"="true"')
-
  bb = [ir.IRBuilder(func.append_basic_block("entry"))]
  loop_blocks: List = []
  reduce_phis: List = []
@@ -136,9 +128,9 @@ class LLVMRenderer(Renderer):
  reduce_phis.append(u)
  elif uop is UOps.LOAD:
  if len(src) > 2:
- aug_idx = bb[-1].select(lvars[src[2]], lvars[src[1]], ir.Constant(ir.IntType(32), 0))
+ aug_idx = bb[-1].select(lvars[src[3]], lvars[src[1]], ir.Constant(ir.IntType(32), 0))
  val = bb[-1].load(bb[-1].gep(lvars[src[0]], [aug_idx], inbounds=True))
- val = bb[-1].select(lvars[src[2]], val, lvars[src[3]])
+ val = bb[-1].select(lvars[src[3]], val, lvars[src[2]])
  else:
  val = bb[-1].load(bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
  lvars[u] = val
@@ -149,10 +141,9 @@ class LLVMRenderer(Renderer):
  while backward.op is UOps.PHI: backward = backward.src[0]
  lvars[backward] = lvars[u]
  elif uop is UOps.ALU:
- lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in src], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPNE) else src[0].dtype)
+ lvars[u] = self.code_for_op[args](bb[-1], *[lvars[x] for x in src], src[0].dtype if args in {BinaryOps.CMPLT, BinaryOps.CMPNE} else dtype)
  elif uop in {UOps.CAST, UOps.BITCAST}: lvars[u] = cast(bb, lvars[src[0]], src[0].dtype, dtype, bitcast=uop is UOps.BITCAST)
  elif uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
- elif uop is UOps.SPECIAL: lvars[u] = lvars[args.expr]
  elif uop is UOps.CONST: lvars[u] = const(args, dtype)
  else: raise RuntimeError(f"failed to render {uop}")
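code_for_op now lives on LLVMRenderer itself (and gains AND/OR), so the ALU case above calls self.code_for_op. A minimal sketch of driving one entry with an llvmlite builder, assuming tinygrad 0.9.2 and llvmlite are installed:

from llvmlite import ir
from tinygrad.dtype import dtypes
from tinygrad.ops import BinaryOps
from tinygrad.renderer.llvmir import LLVMRenderer

mod = ir.Module(name="demo")
fn = ir.Function(mod, ir.FunctionType(ir.FloatType(), [ir.FloatType(), ir.FloatType()]), name="fadd_demo")
bld = ir.IRBuilder(fn.append_basic_block("entry"))
x, y = fn.args
# BinaryOps.ADD on floats lowers to fadd with the renderer's fast-math flags (MFLAGS)
bld.ret(LLVMRenderer.code_for_op[BinaryOps.ADD](bld, x, y, dtypes.float))
print(mod)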