tinygrad 0.10.1__py3-none-any.whl → 0.10.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. tinygrad/codegen/devectorizer.py +247 -0
  2. tinygrad/codegen/expander.py +121 -0
  3. tinygrad/codegen/kernel.py +35 -37
  4. tinygrad/codegen/linearize.py +19 -10
  5. tinygrad/codegen/lowerer.py +31 -8
  6. tinygrad/codegen/symbolic.py +476 -0
  7. tinygrad/codegen/transcendental.py +10 -0
  8. tinygrad/device.py +28 -11
  9. tinygrad/dtype.py +12 -3
  10. tinygrad/engine/jit.py +3 -2
  11. tinygrad/engine/multi.py +0 -1
  12. tinygrad/engine/realize.py +7 -4
  13. tinygrad/engine/schedule.py +227 -255
  14. tinygrad/engine/search.py +20 -27
  15. tinygrad/gradient.py +3 -0
  16. tinygrad/helpers.py +7 -4
  17. tinygrad/nn/state.py +2 -2
  18. tinygrad/ops.py +64 -329
  19. tinygrad/renderer/__init__.py +19 -3
  20. tinygrad/renderer/cstyle.py +39 -18
  21. tinygrad/renderer/llvmir.py +55 -18
  22. tinygrad/renderer/ptx.py +6 -2
  23. tinygrad/renderer/wgsl.py +20 -12
  24. tinygrad/runtime/autogen/libc.py +404 -71
  25. tinygrad/runtime/autogen/{libpciaccess.py → pci.py} +25 -715
  26. tinygrad/runtime/autogen/webgpu.py +6985 -0
  27. tinygrad/runtime/graph/metal.py +28 -29
  28. tinygrad/runtime/ops_amd.py +37 -34
  29. tinygrad/runtime/{ops_clang.py → ops_cpu.py} +4 -2
  30. tinygrad/runtime/ops_disk.py +1 -1
  31. tinygrad/runtime/ops_dsp.py +59 -33
  32. tinygrad/runtime/ops_llvm.py +14 -12
  33. tinygrad/runtime/ops_metal.py +78 -62
  34. tinygrad/runtime/ops_nv.py +9 -6
  35. tinygrad/runtime/ops_python.py +5 -5
  36. tinygrad/runtime/ops_webgpu.py +200 -38
  37. tinygrad/runtime/support/am/amdev.py +23 -11
  38. tinygrad/runtime/support/am/ip.py +10 -10
  39. tinygrad/runtime/support/elf.py +2 -0
  40. tinygrad/runtime/support/hcq.py +7 -5
  41. tinygrad/runtime/support/llvm.py +8 -14
  42. tinygrad/shape/shapetracker.py +3 -2
  43. tinygrad/shape/view.py +2 -3
  44. tinygrad/spec.py +21 -20
  45. tinygrad/tensor.py +150 -90
  46. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
  47. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
  48. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
  49. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
  50. tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
  51. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
  52. tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
  53. tinygrad/viz/index.html +544 -0
  54. tinygrad/viz/perfetto.html +178 -0
  55. tinygrad/viz/serve.py +205 -0
  56. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/METADATA +20 -8
  57. tinygrad-0.10.2.dist-info/RECORD +99 -0
  58. tinygrad/codegen/rewriter.py +0 -516
  59. tinygrad-0.10.1.dist-info/RECORD +0 -86
  60. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
  61. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +0 -0
  62. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
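
A recurring theme in this release is the rename of the CLANG backend to CPU: the runtime moves from ops_clang.py to ops_cpu.py above, and ClangRenderer switches to device = "CPU" in the cstyle.py hunks below. A minimal sketch of what that looks like from the user-facing Tensor API, assuming the usual tinygrad entry points:

    from tinygrad import Tensor, Device

    # the clang-compiled host backend is now addressed as "CPU" rather than "CLANG"
    t = Tensor([1.0, 2.0, 3.0], device="CPU")
    print((t + 1).numpy())    # runs through the renamed CPU backend
    print(Device.DEFAULT)     # may resolve to "CPU" on a machine with no GPU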
tinygrad/renderer/cstyle.py CHANGED
@@ -5,6 +5,7 @@ from tinygrad.ops import GroupOp, Ops, UOp, PatternMatcher, UPat
  from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
  from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
  from tinygrad.renderer import Renderer, TensorCore
+ from tinygrad.codegen.devectorizer import no_vectorized_alu

  base_rewrite = PatternMatcher([
  (UPat(Ops.DEFINE_ACC, name="x"), lambda ctx,x: ctx[x.src[0]]),
@@ -17,7 +18,9 @@ base_rewrite = PatternMatcher([
  lambda ctx,x: f"for ({ctx.render_dtype(x.dtype)} {ctx[x]} = {ctx[x.src[0]]}; {ctx[x]} < {ctx[x.src[1]]}; {ctx[x]}++) {{"),
  (UPat(Ops.VECTORIZE, name="x"),
  lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(x.dtype))}" + \
- (f"{{{','.join([ctx[y] for y in x.src])}}}" if ctx.device == "CLANG" else f"({','.join([ctx[y] for y in x.src])})")),
+ (f"{{{','.join([ctx[y] for y in x.src])}}}" if ctx.device in {'CPU', 'DSP'} else f"({','.join([ctx[y] for y in x.src])})")),
+ (UPat(Ops.CAST, name="x"), lambda ctx,x:
+ f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_dtype(x.dtype)})" if x.dtype.count > 1 and not isinstance(x.dtype, PtrDType) else None),
  (UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx[x.src[0]])})"),
  (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"(*(({ctx.buffer_prefix}{ctx.render_dtype(x.dtype)}*)&{ctx[x.src[0]]}))"),
  (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"),
@@ -49,7 +52,10 @@ base_rewrite = PatternMatcher([
  (UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op](
  *([strip_parens(ctx[v]) if v.op == x.op and x.op in {Ops.ADD, Ops.MUL, Ops.XOR} else ctx[v] for v in x.src]), x.dtype)),
  (UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
- (f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device == 'CLANG' else f".{'xyzwabcd'[x.arg[0]]}")),
+ (f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device in {'CPU', 'DSP'} else \
+ f".{'xyzwabcd'[x.arg[0]]}")),
+ # custom passes through with format
+ (UPat(Ops.CUSTOM, name="x"), lambda ctx,x: x.arg.format(*[ctx[y] for y in x.src])),
  ])

  extra_pm = PatternMatcher([
@@ -58,6 +64,12 @@ extra_pm = PatternMatcher([
  lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not Ops.NOOP else None),
  # rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends)
  (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
+ # devectorize any bools
+ (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.INDEX), dtype=dtypes.bool, name="alu"), no_vectorized_alu),
+ # CAST (from bool) can't be vectorized
+ (UPat(Ops.CAST, src=(UPat(dtype=dtypes.bool),), name="alu"), no_vectorized_alu),
+ # WHERE can't be vectorized
+ (UPat(Ops.WHERE, name="alu"), no_vectorized_alu),
  ])

  def uops_to_dtypes(uops:list[UOp]) -> list[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType)))
@@ -104,10 +116,11 @@ class CStyleLanguage(Renderer):
  if isinstance(dt, ImageDType): return f"{'write_only' if mutable else 'read_only'} image2d_t"
  if isinstance(dt, PtrDType):
  return (self.smem_prefix if dt.local and self.smem_prefix_for_cast else self.buffer_prefix) + self.render_dtype(dt.base) + "*"
- return self.type_map.get(scalar:=dt.scalar(), scalar.name) + (str(dt.count) if (dt.count) > 1 else "")
+ if dt.count > 1: return self.type_map.get(scalar:=dt.scalar(), scalar.name).replace(" ", "_") + str(dt.count)
+ return self.type_map.get(scalar:=dt.scalar(), scalar.name)

  def __getitem__(self, key): return self.r[key] # hacky helper
- def render(self, name:str, uops:list[UOp]) -> str:
+ def render(self, uops:list[UOp]) -> str:
  r: dict[UOp, str] = {}
  self.r = r

@@ -116,7 +129,11 @@ class CStyleLanguage(Renderer):
  kernel = []
  depth = 1
  c: defaultdict[str, int] = defaultdict(int)
+ name = "test"
  for u in uops:
+ if u.op is Ops.NAME:
+ name = u.arg
+ continue
  if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
  r[u] = f"data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else u.arg[0]
  bufs[u] = (r[u], (u.dtype, False))
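
Across the renderers touched in this diff, render(name, uops) becomes render(uops): the kernel name now travels inside the uop list as an Ops.NAME uop, which the loop above pulls out (falling back to "test") before rendering. A rough sketch of the call-site difference; the kernel name "E_4" is only an illustrative placeholder:

    # tinygrad 0.10.1: the kernel name was a separate argument
    src = renderer.render("E_4", uops)

    # tinygrad 0.10.2: the name is carried by an Ops.NAME uop inside `uops`
    src = renderer.render(uops)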
@@ -141,7 +158,7 @@ class CStyleLanguage(Renderer):
  assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"

  if u.op in {Ops.ENDIF, Ops.ENDRANGE}: depth -= 1
- if u.op in {Ops.CONST, Ops.GEP, Ops.INDEX} or \
+ if u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOM} or \
  (u.op in {Ops.VECTORIZE, *GroupOp.ALU, Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA")):
  r[u] = l
  else:
@@ -158,12 +175,15 @@ class CStyleLanguage(Renderer):
  return self.render_kernel(name, kernel, list(bufs.values()), uops)

  class ClangRenderer(CStyleLanguage):
- device = "CLANG"
+ device = "CPU"
  float4 = "(float4)"
  has_local = False
  global_max = None
  infinity = "__builtin_inff()"
  nan = '__builtin_nanf("")'
+ amx_tc = [TensorCore(dims=(sz,sz,1), threads=1, elements_per_thread=(sz,sz,sz*sz), dtype_in=dt, dtype_out=dt, swizzle=(None,((),(4,5,6,7,0,1,2,3))),
+ opts=("u0","u0","u0","u0","u1","u1","u1","u1")) for dt,sz in [(dt, 64 // dt.itemsize) for dt in [dtypes.float]]]
+ if AMX: tensor_cores = amx_tc

  # language options
  buffer_suffix = " restrict"
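
ClangRenderer now defines amx_tc as a class attribute (reused later in this diff by LLVMRenderer via ClangRenderer.amx_tc) instead of building the list inside an `if AMX:` block. The tile size comes from the 64-byte AMX register width; a quick check of the arithmetic for dtypes.float:

    # dims=(sz, sz, 1) with sz = 64 // itemsize; float32 is 4 bytes, so the tile is 16x16x1,
    # and elements_per_thread=(sz, sz, sz*sz) = (16, 16, 256) since AMX runs on a single thread
    itemsize = 4
    sz = 64 // itemsize
    print(sz, (sz, sz, sz * sz))   # 16 (16, 16, 256)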
@@ -174,14 +194,12 @@ class ClangRenderer(CStyleLanguage):
  extra_matcher = PatternMatcher([(UPat.var("x", dtypes.float64).cast(dtypes.float16), lambda x: x.cast(dtypes.float32).cast(dtypes.float16))]) + \
  CStyleLanguage.extra_matcher

- if AMX:
- tensor_cores = [TensorCore(dims=(sz,sz,1), threads=1, elements_per_thread=(sz,sz,sz*sz), dtype_in=dt, dtype_out=dt,
- swizzle=(None, ((),(4,5,6,7,0,1,2,3))), opts=("u0","u0","u0","u0","u1","u1","u1","u1"))
- for dt,sz in [(dt, 64 // dt.itemsize) for dt in [dtypes.float]]]
  if sys.platform == 'win32':
  kernel_prefix = "__attribute__((ms_abi)) "
  def render_vector_prefix(self, dt:DType) -> str:
- return f"typedef {self.render_dtype(dt.scalar())} {self.render_dtype(dt)} __attribute__((aligned({(sz:=dt.itemsize)}),vector_size({sz})));"
+ # round (down) to power of two
+ alignment = 2**int(math.log2(dt.itemsize))
+ return f"typedef {self.render_dtype(dt.scalar())} {self.render_dtype(dt)} __attribute__((aligned({alignment}),vector_size({dt.itemsize})));"

  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
  prefix = [self.render_vector_prefix(dt) for dt in uops_to_dtypes(uops) if dt.count > 1]
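
The old render_vector_prefix used the raw dt.itemsize for both aligned(...) and vector_size(...), which is invalid for non-power-of-two sizes because __attribute__((aligned(N))) requires a power of two. The new version keeps the true byte size for vector_size and rounds only the alignment down; a worked check of the rounding:

    import math
    # alignment = 2**int(math.log2(itemsize)) rounds down to the nearest power of two
    for itemsize in (8, 12, 16, 24, 64):
        print(itemsize, 2**int(math.log2(itemsize)))   # 8->8, 12->8, 16->16, 24->16, 64->64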
@@ -300,10 +318,11 @@ class CUDARenderer(CStyleLanguage):
  local_max = (1024, 1024, 64)
  shared_max = 49152
  # https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-multiply-accumulate-instructions
- tc_81616 = [TensorCore(dims=(8,16,16), threads=32, elements_per_thread=(8,4,4), dtype_in=di,dtype_out=do, opts=cuda_tc_opts,
- swizzle=(((6,7,2,3,4),(0,1,9,5,10,8)), ((6,7,9,0,1),(2,3,4,10,5,8)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.bfloat16,dtypes.float)]]
- tc_8168_f16 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=dtypes.half, dtype_out=dtypes.float, opts=cuda_tc_opts,
- swizzle=(((6,7,2,3,4),(0,1,8,5,9)), ((6,7,8,0,1),(2,3,4,9,5))))]
+ tc_81616 = [TensorCore(dims=(8,16,16), threads=32, elements_per_thread=(8,4,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
+ swizzle=(((6,7,2,3,4),(0,1,9,5,10,8)), ((6,7,9,0,1),(2,3,4,10,5,8)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.bfloat16,dtypes.float),
+ (dtypes.half,dtypes.half)]]
+ tc_8168_f16 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
+ swizzle=(((6,7,2,3,4),(0,1,8,5,9)), ((6,7,8,0,1),(2,3,4,9,5)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.half,dtypes.half)]]
  tc_8168_tf32 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=dtypes.float, dtype_out=dtypes.float, opts=cuda_tc_opts,
  swizzle=(((5,6,2,3,4),(0,1,8,9,7)), ((5,6,8,0,1),(2,3,4,9,7))))]

@@ -344,7 +363,8 @@ class CUDARenderer(CStyleLanguage):
  if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("#include <cuda_bf16.h>")
  prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count in (4,8) and dt.scalar() in {dtypes.half, dtypes.bfloat16}]

- dt_map = { dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16" }
+ dt_map_in = { dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16" }
+ dt_map_out = { dtypes.float: "f32", dtypes.half: "f16" }
  for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):
  upcast_sizes = [prod(size for _, size in upcast) for upcast in upcast_axes]
  wmma_dtypes = [self.render_dtype(dtype.vec(size)) for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)]
@@ -353,10 +373,11 @@ class CUDARenderer(CStyleLanguage):

  # mma operands => {c}, {a}, {b}, {c}
  prefix.append(f"""__device__ {wmma_dtypes[2]} __{name}({wmma_dtypes[0]} a, {wmma_dtypes[1]} b, {wmma_dtypes[2]} c){{
- int *a_pk = (int *)(&a), *b_pk = (int *)(&b);\n asm("mma.sync.aligned.m{M}n{N}k{K}.row.col.f32.{dt_map[dtype_in]}.{dt_map[dtype_in]}.f32"
+ int *a_pk = (int *)(&a), *b_pk = (int *)(&b), *c_pk = (int *)(&c);
+ asm("mma.sync.aligned.m{M}n{N}k{K}.row.col.{dt_map_out[dtype_out]}.{dt_map_in[dtype_in]}.{dt_map_in[dtype_in]}.{dt_map_out[dtype_out]}"
  "{{{", ".join(operands[:n_operands[2]])}}}, {{{", ".join(operands[n_operands[2]:n_operands[2]+n_operands[0]])}}},"
  "{{{", ".join(operands[-n_operands[1]:])}}}, {{{", ".join(operands[:n_operands[2]])}}};"
- : {", ".join([f'"+f"(c.{_nms[i]})' for i in range(n_operands[2])])}
+ : {", ".join([f'"+r"(c_pk[{i}])' for i in range(n_operands[2])])}
  : {", ".join([f'"r"(a_pk[{i}])' for i in range(n_operands[0])])}, {", ".join([f'"r"(b_pk[{i}])' for i in range(n_operands[1])])});
  return c;\n}}""")

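With dt_map split into dt_map_in/dt_map_out and the new (dtypes.half, dtypes.half) entries in tc_81616 and tc_8168_f16, the generated mma.sync instruction now spells out the accumulator type instead of hard-coding f32. A small sketch of the instruction name this template produces (the WMMA arg stores dims as (N, M, K), hence m{M}n{N}k{K}); dtype names are plain strings here just for illustration:

    dt_map_in  = {"float": "tf32", "half": "f16", "bfloat16": "bf16"}
    dt_map_out = {"float": "f32", "half": "f16"}

    def mma_name(N, M, K, di, do):
        return f"mma.sync.aligned.m{M}n{N}k{K}.row.col.{dt_map_out[do]}.{dt_map_in[di]}.{dt_map_in[di]}.{dt_map_out[do]}"

    print(mma_name(8, 16, 16, "half", "float"))  # mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
    print(mma_name(8, 16, 16, "half", "half"))   # mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16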
tinygrad/renderer/llvmir.py CHANGED
@@ -1,10 +1,13 @@
  from typing import cast
- import math, struct
+ import math, struct, sys
  from tinygrad.renderer import Renderer
+ from tinygrad.renderer.cstyle import ClangRenderer
  from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, GroupOp
  from tinygrad.dtype import dtypes, DType, PtrDType, truncate
+ from tinygrad.helpers import prod, AMX

  def ldt(dt:DType):
+ if dt.vcount > 1: return f"<{dt.vcount} x {ldt(dt.scalar())}>"
  if isinstance(dt, PtrDType): return ldt(dt.base) + "*"
  return {dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64",
  dtypes.uint8: "i8", dtypes.uint16: "i16", dtypes.uint32: "i32", dtypes.uint64: "i64",
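
The new first branch of ldt renders vector dtypes as LLVM vector types, which is what lets the LLVM backend turn on supports_float4 later in this file. A small sketch, assuming the usual scalar mapping where float32 renders as "float":

    from tinygrad.dtype import dtypes

    dt = dtypes.float.vec(4)
    print(dt.vcount)   # 4
    # with the new branch, ldt(dt) == f"<{dt.vcount} x {ldt(dt.scalar())}>" == "<4 x float>"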
@@ -20,7 +23,7 @@ def lcast(input_type:DType, output_type:DType):
  if dtypes.is_float(input_type):
  if dtypes.is_float(output_type): return 'fpext' if output_type.itemsize > input_type.itemsize else 'fptrunc'
  if dtypes.is_int(output_type): return 'fptoui' if dtypes.is_unsigned(output_type) else 'fptosi'
- if dtypes.is_unsigned(input_type) or input_type == dtypes.bool:
+ if dtypes.is_unsigned(input_type) or dtypes.is_bool(input_type):
  if dtypes.is_float(output_type): return 'uitofp'
  if dtypes.is_int(output_type): return 'trunc' if output_type.itemsize < input_type.itemsize else 'zext'
  if dtypes.is_int(input_type):
@@ -28,6 +31,19 @@ def lcast(input_type:DType, output_type:DType):
  if dtypes.is_int(output_type): return 'trunc' if output_type.itemsize < input_type.itemsize else 'sext'
  raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")

+ # https://github.com/corsix/amx
+ def render_wmma(ctx, wmma: UOp) -> str:
+ def AMX(op, gpr): return f'call void asm sideeffect ".word (0x201000+($0<<5)+0$1-((0$1>>4)*6))", "i,r,~{{memory}}"(i32 {op}, i64 {gpr}) #0; AMX'
+
+ return "\n".join([
+ *[f' store {ldt(src.dtype)} {ctx[src]}, {ldt(src.dtype.ptr())} {ctx[wmma]}_amx{i}, align {src.dtype.itemsize}' for i,src in enumerate(wmma.src)],
+ f' call void asm sideeffect "nop\\0Anop\\0Anop\\0A.word ({0x201000 + (17 << 5) + 0})", "~{{memory}}"() #0; AMX set', # set
+ *[f' {ctx[wmma]}_ld{i} = add i64 {ctx[wmma]}_ptr_amx2, {i*4<<56 | i*64}\n {AMX(4,f"{ctx[wmma]}_ld{i}")} ldz' for i in range(16)], # ldz
+ f' {AMX(0, f"{ctx[wmma]}_ptr_amx1")} ldx\n {AMX(1, f"{ctx[wmma]}_ptr_amx0")} ldy\n {AMX(12, 0)} fma32', # ldx ldy fma
+ *[f' {ctx[wmma]}_st{i} = add i64 {ctx[wmma]}_ptr_amx2, {i*4<<56 | i*64}\n {AMX(5,f"{ctx[wmma]}_st{i}")} stz' for i in range(16)], # stz
+ f' call void asm sideeffect "nop\\0Anop\\0Anop\\0A.word ({0x201000 + (17 << 5) + 1})", "~{{memory}}"() #0; AMX clr', # clr
+ f' {ctx[wmma]} = load {ldt(wmma.dtype)}, ptr {ctx[wmma]}_amx2, align {wmma.dtype.itemsize}'])
+
  # llvm ops, lop[<dtype>][<op>]
  unsigned_lop = { Ops.ADD: "add", Ops.MUL: "mul", Ops.IDIV: "udiv", Ops.MOD: "urem",
  Ops.CMPLT: "icmp ult", Ops.CMPNE: "icmp ne", Ops.OR: "or", Ops.AND: "and", Ops.XOR: "xor", }
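
The new render_wmma drives Apple AMX through raw .word inline asm, following the encoding documented at https://github.com/corsix/amx. A sanity check of the constants that appear above, assuming that encoding (0x201000 + (op << 5) + operand):

    AMX_BASE = 0x201000
    # op 17 with operand 0/1 is the set/clr pair emitted before and after the matmul
    print(hex(AMX_BASE + (17 << 5) + 0))   # 0x201220 -> "AMX set"
    print(hex(AMX_BASE + (17 << 5) + 1))   # 0x201221 -> "AMX clr"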
@@ -36,7 +52,7 @@ flags = " nsz arcp contract afn"
  float_lop = {Ops.ADD: "fadd"+flags, Ops.MUL: "fmul"+flags, Ops.CMPLT: f"fcmp{flags} ult", Ops.CMPNE: f"fcmp{flags} une", Ops.FDIV: "fdiv"+flags}
  lop = {**{x:unsigned_lop for x in (dtypes.bool,)+dtypes.uints}, **{x:signed_lop for x in dtypes.sints}, **{x:float_lop for x in dtypes.floats}}

- llvm_rewrite = PatternMatcher([
+ base_rewrite = PatternMatcher([
  # memory load/store
  (UPat(Ops.INDEX, name="x"), lambda ctx,x:
  f" {ctx[x]} = getelementptr inbounds {ldt(x.dtype.base)}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}"),
@@ -49,12 +65,22 @@ llvm_rewrite = PatternMatcher([
  (UPat(Ops.LOAD, src=(UPat.var('idx'),), name="x"), lambda ctx,x,idx: f" {ctx[x]} = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}"),
  (UPat(Ops.STORE, name="x"), lambda ctx,x: f" store {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}"),

+ # GEP/VECTORIZE/CAST for float4 support
+ (UPat(Ops.GEP, name="x"), lambda ctx,x: f" {ctx[x]} = extractelement {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, i32 {x.arg[0]}"),
+ (UPat(Ops.VECTORIZE, src=UPat.var('y'), name="x"), lambda ctx,x,y:
+ f" {ctx[x]}_z = insertelement <1 x {ldt(y.dtype)}> poison, {ldt(y.dtype)} {ctx[y]}, i32 0\n"
+ f" {ctx[x]} = shufflevector <1 x {ldt(y.dtype)}> {ctx[x]}_z, <1 x {ldt(y.dtype)}> poison, <{x.dtype.count} x i32> zeroinitializer"),
+ (UPat(Ops.VECTORIZE, name="x"), lambda ctx,x: "\n".join([(f" {ctx[x]}_{i}" if i+1 != len(x.src) else f" {ctx[x]}")+
+ f" = insertelement {ldt(x.dtype)} "+(f"{ctx[x]}_{i-1}" if i != 0 else "poison")+
+ f", {ldt(u.dtype)} {ctx[u]}, i32 {i}" for i,u in enumerate(x.src)])),
+ (UPat(Ops.CAST, name="x"), lambda ctx,x:
+ f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}" if isinstance(x.dtype, PtrDType) else None),
+
  # unary/binary/ternary ops
- (UPat(Ops.SQRT, name="x"), lambda ctx,x:
- f" {ctx[x]} = call{flags} {ldt(x.dtype)} @llvm.sqrt.{ldt(x.src[0].dtype)}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"),
  (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
  (UPat(Ops.CAST, name="x"), lambda ctx,x: f" {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
- (UPat(GroupOp.Binary, name="x"), lambda ctx,x: f" {ctx[x]} = {lop[x.src[0].dtype][x.op]} {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ctx[x.src[1]]}"),
+ (UPat(GroupOp.Binary, name="x"), lambda ctx,x:
+ f" {ctx[x]} = {lop[x.src[0].dtype.scalar()][x.op]} {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ctx[x.src[1]]}"),
  (UPat(Ops.WHERE, name="x"), lambda ctx,x:
  f" {ctx[x]} = select {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[2].dtype)} {ctx[x.src[2]]}"),

@@ -71,6 +97,9 @@ llvm_rewrite = PatternMatcher([
  # if
  (UPat(Ops.IF, name="x"), lambda ctx,x: f" br i1 {ctx[x.src[0]]}, label %ifbody_{ctx[x][1:]}, label %ifskip_{ctx[x][1:]}\nifbody_{ctx[x][1:]}:"),
  (UPat(Ops.ENDIF, name="x"), lambda ctx,x: f" br label %ifskip_{ctx[x.src[0]][1:]}\nifskip_{ctx[x.src[0]][1:]}:"),
+
+ # wmma
+ (UPat(Ops.WMMA, name="wmma"), render_wmma),
  ])

  def llvm_bf16_cast(buf:UOp, idx:UOp, root:UOp):
@@ -79,10 +108,13 @@ def llvm_bf16_cast(buf:UOp, idx:UOp, root:UOp):

  class LLVMRenderer(Renderer):
  device = "LLVM"
- supports_float4 = False
+ abi = 'win64cc' if sys.platform == 'win32' else None
+ supports_float4 = True
  has_local = False
  has_shared = False
  global_max = None
+ string_rewrite = base_rewrite
+ if AMX: tensor_cores = ClangRenderer.amx_tc

  extra_matcher = PatternMatcher([
  # rewrite RECIP with FDIV
@@ -95,32 +127,36 @@ class LLVMRenderer(Renderer):
  (UPat(Ops.CAST, name="root", src=(UPat.load(UPat.index(UPat.var("buf"), UPat.var("idx")), dtype=dtypes.bfloat16),)), llvm_bf16_cast),
  ])

- def __init__(self, abi:str|None=None):
- self.abi = abi
-
- def render(self, name: str, uops: list[UOp]) -> str:
+ def render(self, uops: list[UOp]) -> str:
  r: dict[UOp, str] = {}
  args: list[str] = []
  kernel: list[str] = []
  end_lines: dict[str, None] = {}
  vc = -1

- # prealloc all assigns
  acc_to_assign: dict[UOp, UOp] = {}
  for u in uops:
- if u.op is Ops.ASSIGN:
+ if u.op is Ops.ASSIGN: # prealloc all assigns
  vc += 1
  r[u] = r[u.src[1]] = f"%assign{vc}"
  assert u.src[0] not in acc_to_assign, "can't assign to DEFINE_ACC twice"
  acc_to_assign[u.src[0]] = u.src[1]
+ if u.op is Ops.WMMA: # prealloc aux buffers as AMX can only load from memory
+ vc += 1
+ r[u] = f"%wmma{vc}"
+ for i, dtype in enumerate(u.arg[2].vec(sz) for sz in [prod(size for _, size in upcast) for upcast in u.arg[6]]):
+ kernel += [f" {r[u]}_amx{i} = alloca {ldt(dtype)}, align {dtype.itemsize}",
+ f" {r[u]}_ptr_amx{i} = ptrtoint {ldt(dtype.ptr())} {r[u]}_amx{i} to i64"]

+ name = "test"
  for u in uops:
- # hack for defining sqrt function (TODO: can we get a transcendental for this?)
- if u.op is Ops.SQRT: end_lines[f'declare {ldt(u.dtype)} @llvm.sqrt.{ldt(u.dtype)}({ldt(u.dtype)} %".1")'] = None
-
+ if u.op is Ops.NAME:
+ name = u.arg
+ continue
  if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
  r[u] = f"%data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else f"%{u.arg[0]}"
- args.append(f"{ldt(u.dtype)}{' noalias' if isinstance(u.dtype, PtrDType) else ''} {r[u]}")
+ # NOTE: MallocAllocator promises 0x20 alignment
+ args.append(f"{ldt(u.dtype)}{' noalias align 32' if isinstance(u.dtype, PtrDType) else ''} {r[u]}")
  elif u.op is Ops.ASSIGN: pass # assign is already handled by the first pass
  elif u.op is Ops.DEFINE_ACC: r[u] = r[u.src[0]] # a define acc can be used and never be assigned to
  elif u.op is Ops.CONST: r[u] = lconst(u.arg, u.dtype)
@@ -132,7 +168,8 @@ class LLVMRenderer(Renderer):
  r[u] = f"%v{vc}"

  # do the rendering of the llvm ir code
- if (l:=llvm_rewrite.rewrite(u, ctx=r)) is None: raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.src]}")
+ if (l:=self.string_rewrite.rewrite(u, ctx=r)) is None:
+ raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.src]}")
  kernel.append(cast(str, l))

  # generate the phi nodes for the assigns
tinygrad/renderer/ptx.py CHANGED
@@ -65,7 +65,7 @@ def render_wmma(ctx: "PTXRenderer", wmma: UOp):
  if (elems_per_reg := 4 // src.dtype.scalar().itemsize) == 1: yield f"mov.b32 {reg}, {ctx.r[src][i]};"
  else: yield f"mov.b32 {reg}, {{{', '.join(ctx.r[src][i * elems_per_reg : (i+1) * elems_per_reg])}}};"

- dt_map_in, dt_map_out = {dtypes.float: "tf32", dtypes.half: "f16"}, {dtypes.float: "f32"}
+ dt_map_in, dt_map_out = {dtypes.float: "tf32", dtypes.half: "f16"}, {dtypes.float: "f32", dtypes.half: "f16"}
  yield f'mma.sync.aligned.m{M}n{N}k{K}.row.col.{dt_map_out[dtype_out]}.{dt_map_in[dtype_in]}.{dt_map_in[dtype_in]}.{dt_map_out[dtype_out]}{" "*12}'+\
  f'{{{", ".join(ctx.wmma_r[2])}}}, {{{", ".join(ctx.wmma_r[0])}}}, {{{", ".join(ctx.wmma_r[1])}}}, {{{", ".join(ctx.wmma_r[2])}}};'

@@ -154,7 +154,7 @@ class PTXRenderer(Renderer):
  params = ',\n\t'.join([f".param .{'u64' if dtype.__class__ == PtrDType else self.types[dtype]} {name}" for name,dtype in bufs])
  return f"{self.kernel_prefix} {function_name}(\n\t{params}\n)\n{{\n{kernel}\n}}"

- def render(self, name:str, uops:list[UOp]) -> str:
+ def render(self, uops:list[UOp]) -> str:
  kernel:list[str] = []
  bufs = []

@@ -169,7 +169,11 @@ class PTXRenderer(Renderer):
  c[prefix] += 1
  return f"%{prefix}{c[prefix]-1}"

+ name = "test"
  for u in uops:
+ if u.op is Ops.NAME:
+ name = u.arg
+ continue
  if u.op is Ops.VECTORIZE:
  r[u] = [cast(str,r[x]) for x in u.src]
  continue
tinygrad/renderer/wgsl.py CHANGED
@@ -25,16 +25,19 @@ def packed_load(root:UOp, bidx:UOp, dtype:DType, var:UOp|None=None):
  val = (load.cast(dtypes.uint32) >> shift_am) & (0xFF if dtype.itemsize == 1 else 0xFFFF)
  return sign_extend(val, 8*dtype.itemsize).cast(dtype) if dtype in [dtypes.char, dtypes.short] else val.cast(dtype)

+ def is_packed(dt:DType) -> bool: return dt.itemsize < 4 and dt.base != dtypes.half
+
  wgsl_matcher = PatternMatcher([
  (UPat((Ops.CMPLT, Ops.XOR), src=(UPat(name="a", dtype=dtypes.bool), UPat.var("b")), name="c"),
  lambda a,b,c: a.cast(dtypes.int).alu(c.op, b.cast(dtypes.int)).cast(dtypes.bool)),
- (UPat(Ops.LOAD, name="l", src=(UPat.var('b'),)), lambda l,b: packed_load(l,b,l.dtype) if l.dtype.itemsize < 4 else None),
+ (UPat(Ops.LOAD, name="l", src=(UPat.var('b'),)), lambda l,b: packed_load(l,b,l.dtype) if is_packed(l.dtype) else None),
  (UPat(Ops.LOAD, name="l", src=(UPat.var('b'), UPat.var('c'), UPat())),
- lambda l,b,c: packed_load(l,b,l.dtype,c.cast(dtypes.uint32)) if l.dtype.itemsize < 4 else None),
- (UPat.store(UPat.var("bidx"), UPat.var("var"), allow_any_len=True), lambda bidx,var: packed_store(bidx,var) if var.dtype.itemsize < 4 else None),
+ lambda l,b,c: packed_load(l,b,l.dtype,c.cast(dtypes.uint32)) if is_packed(l.dtype) else None),
+ (UPat.store(UPat.var("bidx"), UPat.var("var"), allow_any_len=True), lambda bidx,var: packed_store(bidx,var) if is_packed(var.dtype) else None),
  # TODO: why is this needed, and only for this MUL order
  (UPat(Ops.MUL, src=(UPat.var("a"), UPat.var("g").where(UPat.cvar("c1"), UPat.cvar("c2")))),
  lambda a,g,c1,c2: g.where(c1, a) if math.isnan(c1.arg) and c2.arg == 1.0 else None),
+ (UPat.var("a") << UPat.var("b"),lambda a,b:(a.bitcast(dtypes.uint32)<<b.cast(dtypes.uint32)).bitcast(a.dtype) if b.dtype!=dtypes.uint32 else None)
  ]) + extra_pm

  class WGSLRenderer(CStyleLanguage):
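
The new is_packed helper narrows the packed (u32-emulated) load/store path now that f16 can be rendered natively: only sub-4-byte types other than half still go through packed_load/packed_store. A rough sketch of which dtypes that covers:

    from tinygrad.dtype import dtypes

    def is_packed(dt) -> bool: return dt.itemsize < 4 and dt.base != dtypes.half

    for dt in (dtypes.char, dtypes.uchar, dtypes.short, dtypes.ushort, dtypes.bool, dtypes.half, dtypes.int):
        print(dt, is_packed(dt))   # True for the 1- and 2-byte int types and bool, False for half and int32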
@@ -48,38 +51,43 @@ class WGSLRenderer(CStyleLanguage):
  code_for_op = {**CStyleLanguage.code_for_op, Ops.WHERE: lambda a,b,c,dtype: f"select({c},{b},{a})"}
  nan = "nan()"
  type_map = { dtypes.float: "f32", dtypes.uchar: "u32", dtypes.ushort: "u32", dtypes.short: "i32",
- dtypes.char: "i32", dtypes.int32: "i32", dtypes.uint32: "u32", dtypes.bool: "bool" }
+ dtypes.char: "i32", dtypes.int32: "i32", dtypes.uint32: "u32", dtypes.bool: "bool", dtypes.half: "f16" }

  string_rewrite = PatternMatcher([
  (UPat(Ops.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "true" if x.arg else "false"),
  (UPat(Ops.CONST, dtype=(dtypes.uchar, dtypes.ushort, dtypes.uint32), name="x"), lambda ctx,x: f"bitcast<u32>({x.arg})" \
  if x.arg < 0 else f"{x.arg&0xFFFFFFFF}u"),
  (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var<workgroup> {ctx[x]}: array<{ctx.buf_map(x.dtype.base)}, {x.dtype.size}>;"),
- (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]}{['&0xFF','&0xFFFF','',''][x.dtype.itemsize-1]})"),
+ (UPat(Ops.BITCAST, dtype=dtypes.half, name="x"), lambda ctx,x: f"bitcast<vec2<f16>>({ctx[x.src[0]]})[0]" \
+ if x.src[0].dtype in [dtypes.short, dtypes.ushort, dtypes.uint32] else None),
+ (UPat(Ops.BITCAST, dtype=(dtypes.char, dtypes.uchar), name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]}&0xFF)"),
+ (UPat(Ops.BITCAST, dtype=(dtypes.short, dtypes.ushort), name="x"),lambda ctx,x:f"bitcast<{ctx.type_map[x.dtype]}>(vec2<f16>({ctx[x.src[0]]},0))" \
+ if x.src[0].dtype == dtypes.half else f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]}&0xFFFF)"),
+ (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]})"),
  (UPat.load(UPat.var("b"),UPat.var("v"),UPat.var("g")),lambda ctx,b,v,g:f"select({ctx[v]}, {ctx.render_load(ctx[b],b.src[0].dtype)}, {ctx[g]})"),
  (UPat.load(UPat.var("b"), allow_any_len=True), lambda ctx, b: ctx.render_load(ctx[b], b.src[0].dtype)),
  (UPat.index(UPat.var("b"), UPat.var("idx")), lambda ctx,b,idx: f"{ctx[b]}[{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]}]"),
  (UPat.store(UPat.var('b'), UPat.var("v"), allow_any_len=True),lambda ctx,b,v:\
  # (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
- f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if b.src[0].dtype.itemsize < 4 \
+ f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if is_packed(b.src[0].dtype) \
  else f"{ctx[b]} = {ctx[v]};"),
  # fix nan check: 'a != a -> is_nan()'
- (UPat.var("a") != UPat.var("a"), lambda ctx,a: f"is_nan({ctx[a]})"),
+ (UPat.var("a") != UPat.var("a"), lambda ctx,a: f"(min({ctx[a]}, 1.0) == 1.0 && max({ctx[a]}, -1.0) == -1.0)"),
  ]) + base_rewrite

  def render_cast(self, dt:DType, val: str) -> str: return f"{self.type_map[dt]}({val})"
  def render_dtype(self, dt:DType, mutable=True) -> str: return "var"
- def render_load(self, x:str, dt:DType) -> str: return f"atomicLoad(&{x})" if dt.itemsize < 4 else x
- def buf_map(self, dt:DType) -> str: return "atomic<u32>" if dt.itemsize < 4 else self.type_map[dt.base]
+ def render_load(self, x:str, dt:DType) -> str: return f"atomicLoad(&{x})" if is_packed(dt) else x
+ def buf_map(self, dt:DType) -> str: return "atomic<u32>" if is_packed(dt) else self.type_map[dt.base]
  def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[DType,bool]]], uops:list[UOp], prefix=None) -> str:
  local_size = [num for _, num in sorted([u.arg for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == 'l'], key=lambda x: x[0])]
  if not local_size: local_size = [1]
  bind_it = iter(range(len(bufs)))
  external_local_bufs = [line.lstrip() for line in kernel if "var<workgroup>" in line]
  kernel[:] = [line for line in kernel if "var<workgroup>" not in line]
- prg = "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }\n"
- # trick to obfuscate compiler so that nan is detected properly
- prg += "fn is_nan(v:f32) -> bool { return min(v, 1.0) == 1.0 && max(v, -1.0) == -1.0; }\n@group(0) @binding(0)\nvar<uniform> INFINITY : f32;\n"
+ prg = "enable f16;\n" if any(uop.dtype.base == dtypes.half for uop in uops) else ""
+ prg += "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }\n"
+ prg += "@group(0) @binding(0)\nvar<uniform> INFINITY : f32;\n"
  prg += "\n".join((external_local_bufs or [])+[f"@group(0) @binding({next(bind_it)+1})" +
  f"{'var<storage,read_write>' if isinstance(dtype, PtrDType) else 'var<uniform>'}" +
  f"{name}:{f'array<{self.buf_map(dtype.base)}>' if isinstance(dtype,PtrDType) else self.buf_map(dtype)};" for name,(dtype,_) in bufs])