tinygrad 0.10.0__py3-none-any.whl → 0.10.2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (88)
  1. tinygrad/codegen/devectorizer.py +247 -0
  2. tinygrad/codegen/expander.py +121 -0
  3. tinygrad/codegen/kernel.py +141 -201
  4. tinygrad/codegen/linearize.py +223 -84
  5. tinygrad/codegen/lowerer.py +60 -42
  6. tinygrad/codegen/symbolic.py +476 -0
  7. tinygrad/codegen/transcendental.py +22 -13
  8. tinygrad/device.py +187 -47
  9. tinygrad/dtype.py +39 -28
  10. tinygrad/engine/jit.py +83 -65
  11. tinygrad/engine/memory.py +4 -5
  12. tinygrad/engine/multi.py +161 -0
  13. tinygrad/engine/realize.py +62 -108
  14. tinygrad/engine/schedule.py +396 -357
  15. tinygrad/engine/search.py +55 -66
  16. tinygrad/gradient.py +73 -0
  17. tinygrad/helpers.py +81 -59
  18. tinygrad/nn/__init__.py +30 -32
  19. tinygrad/nn/datasets.py +1 -2
  20. tinygrad/nn/optim.py +22 -26
  21. tinygrad/nn/state.py +91 -66
  22. tinygrad/ops.py +492 -641
  23. tinygrad/renderer/__init__.py +95 -36
  24. tinygrad/renderer/cstyle.py +99 -92
  25. tinygrad/renderer/llvmir.py +83 -34
  26. tinygrad/renderer/ptx.py +83 -99
  27. tinygrad/renderer/wgsl.py +95 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  29. tinygrad/runtime/autogen/comgr.py +2 -0
  30. tinygrad/runtime/autogen/kfd.py +4 -3
  31. tinygrad/runtime/autogen/kgsl.py +1 -1
  32. tinygrad/runtime/autogen/libc.py +404 -71
  33. tinygrad/runtime/autogen/llvm.py +11379 -0
  34. tinygrad/runtime/autogen/pci.py +1333 -0
  35. tinygrad/runtime/autogen/vfio.py +891 -0
  36. tinygrad/runtime/autogen/webgpu.py +6985 -0
  37. tinygrad/runtime/graph/cuda.py +8 -9
  38. tinygrad/runtime/graph/hcq.py +84 -79
  39. tinygrad/runtime/graph/metal.py +40 -43
  40. tinygrad/runtime/ops_amd.py +498 -334
  41. tinygrad/runtime/ops_cloud.py +34 -34
  42. tinygrad/runtime/ops_cpu.py +24 -0
  43. tinygrad/runtime/ops_cuda.py +30 -27
  44. tinygrad/runtime/ops_disk.py +62 -63
  45. tinygrad/runtime/ops_dsp.py +159 -42
  46. tinygrad/runtime/ops_gpu.py +30 -30
  47. tinygrad/runtime/ops_hip.py +29 -31
  48. tinygrad/runtime/ops_llvm.py +48 -41
  49. tinygrad/runtime/ops_metal.py +149 -113
  50. tinygrad/runtime/ops_npy.py +2 -2
  51. tinygrad/runtime/ops_nv.py +238 -273
  52. tinygrad/runtime/ops_python.py +55 -50
  53. tinygrad/runtime/ops_qcom.py +129 -157
  54. tinygrad/runtime/ops_webgpu.py +225 -0
  55. tinygrad/runtime/support/allocator.py +94 -0
  56. tinygrad/runtime/support/am/__init__.py +0 -0
  57. tinygrad/runtime/support/am/amdev.py +396 -0
  58. tinygrad/runtime/support/am/ip.py +463 -0
  59. tinygrad/runtime/support/compiler_cuda.py +4 -2
  60. tinygrad/runtime/support/elf.py +28 -4
  61. tinygrad/runtime/support/hcq.py +256 -324
  62. tinygrad/runtime/support/llvm.py +26 -0
  63. tinygrad/shape/shapetracker.py +85 -53
  64. tinygrad/shape/view.py +104 -140
  65. tinygrad/spec.py +155 -0
  66. tinygrad/tensor.py +835 -527
  67. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
  68. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
  69. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
  70. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
  71. tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
  72. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
  73. tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
  74. tinygrad/viz/index.html +544 -0
  75. tinygrad/viz/perfetto.html +178 -0
  76. tinygrad/viz/serve.py +205 -0
  77. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
  78. tinygrad-0.10.2.dist-info/RECORD +99 -0
  79. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
  80. tinygrad/codegen/uopgraph.py +0 -506
  81. tinygrad/engine/lazy.py +0 -228
  82. tinygrad/function.py +0 -212
  83. tinygrad/multi.py +0 -177
  84. tinygrad/runtime/graph/clang.py +0 -39
  85. tinygrad/runtime/ops_clang.py +0 -35
  86. tinygrad-0.10.0.dist-info/RECORD +0 -77
  87. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
  88. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
@@ -1,41 +1,100 @@
-from typing import Optional, List, Tuple, Dict, Callable, Any
-import functools
-from dataclasses import dataclass, field
+from __future__ import annotations
+from typing import Optional, Callable
+import functools, math
+from enum import Enum, auto
+from dataclasses import dataclass, field, replace
 from tinygrad.helpers import to_function_name, dedup, prod
-from tinygrad.ops import Ops, UOp, flops_mem, sym_infer, sint, Variable
+from tinygrad.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, GroupOp, PatternMatcher
 from tinygrad.dtype import DType

+class OptOps(Enum):
+  TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
+  GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); SWAP = auto() # noqa: E702
+  def __lt__(self, x:OptOps): return self.value < x.value
+
+@dataclass(frozen=True, order=True)
+class Opt:
+  op: OptOps
+  axis: Optional[int] = None
+  arg: Optional[int | tuple] = None
+  def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, arg={self.arg})"
+
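The new OptOps enum and frozen Opt dataclass describe individual kernel optimizations (upcast, unroll, local, ...) alongside the renderer types. A minimal sketch of constructing and ordering them, assuming they are imported from tinygrad.renderer as defined here (the axis and arg values are made up for illustration):

    from tinygrad.renderer import Opt, OptOps

    # upcast axis 0 by a factor of 4, then unroll axis 1 by 2 (illustrative values)
    opts = [Opt(OptOps.UPCAST, axis=0, arg=4), Opt(OptOps.UNROLL, axis=1, arg=2)]
    assert opts[0].op < opts[1].op  # OptOps.__lt__ orders by declaration: UPCAST before UNROLL
    print(opts[0])                  # Opt(op=OptOps.UPCAST, axis=0, arg=4)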
 @dataclass(frozen=True)
 class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N)
-  dims: Tuple[int,int,int] # N, M, K
+  dims: tuple[int,int,int] # N, M, K
+  threads: int # number of threads that construct the warp
+  elements_per_thread: tuple[int, int, int] # elements per-thread to load/store from A/B/C
   dtype_in: DType # dtype for A and B
   dtype_out: DType # dtype for C and D
-  threads: List[Tuple[int,int]] # list of (TC dim,amt) that construct the warp thread structure
-  reduce_axes: List[Tuple[int,int]] # list of (TC dim,amt) that constructs the shape of the reduce dim
-  @property
-  def early_upcast_axes(self) -> List[Tuple[int,int]]: # list of (TC dim,amt) that upcasts the threads remainders of dims [0,1]
-    return [(d,self.dims[d]//sz) for d,sz in [(dim,prod(sz for d,sz in self.threads if d==dim)) for dim in range(2)] if self.dims[d]>sz]
-  upcast_axes: Tuple[List[Tuple[int,int]], List[Tuple[int,int]], List[Tuple[int,int]]] # list of (TC dim,amt) that upcast A, B and C
-  st1_pattern: Optional[Tuple[Tuple[Tuple[int,int], ...], Tuple[Tuple[int,int], ...]]] = None # pattern to fix shapetracker for A
-  st2_pattern: Optional[Tuple[Tuple[Tuple[int,int], ...], Tuple[Tuple[int,int], ...]]] = None # pattern to fix shapetracker for B
-  expanded_shape: Optional[Tuple[int, ...]] = None
-  opts_seq: Tuple[str,str] = ("UP","LC") # upcast input, local the thread pattern
+  opts: tuple[str, ...] # ordered tuple of "ux" or "lx" specifing kernel opts to perform. "ux" upcasts dim x and "lx" localizes dim x
+  swizzle: tuple[Optional[tuple[tuple[int, ...], tuple[int, ...]]], Optional[tuple[tuple[int, ...], tuple[int, ...]]]] = (None, None)
+  def get_reduce_axes(self): return [(i, 2) for i in range(int(math.log2(self.dims[2])))]
+  def get_upcast_axes(self): return [opt for opt in self.opts if opt[0] == "u"]
+  def get_local_axes(self): return [opt for opt in self.opts if opt[0] == "l"]
   def __str__(self): return "_".join(["WMMA"] + list(map(str, self.dims)) + [self.dtype_in.name, self.dtype_out.name])
+  def __post_init__(self):
+    local_axes, upcast_axes, reduce_axes = len(self.get_local_axes()), len(self.get_upcast_axes()), len(self.get_reduce_axes())
+    assert self.dims[0] * self.dims[1] == 2**(local_axes + upcast_axes), (
+      f"N({self.dims[0]}) x M({self.dims[1]}) != local({2**local_axes}) x upcast({2**upcast_axes}) with opts({self.opts})")
+    assert 2**local_axes == self.threads, f"{self.threads} threads construct the warp but found {2**local_axes} in {self.opts}"
+    assert 2**upcast_axes == self.elements_per_thread[2], (
+      f"{self.elements_per_thread[2]} elements from C are processed per thread but found {2**upcast_axes} in {self.opts}")
+    assert all(len(perm[0]) == local_axes and len(perm[1]) == reduce_axes + upcast_axes for perm in self.swizzle if perm), (
+      f"swizzle perm should be of len (({local_axes})({reduce_axes + upcast_axes}))")
+
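TensorCore's new __post_init__ ties the fields together: each "l" opt contributes a factor of two to the warp size, each "u" opt a factor of two to the per-thread C elements, and the reduce axes are the log2 factorization of K. A worked check against the Metal tensor core declared later in this diff (dims=(8,8,8), threads=32, elements_per_thread=(2,2,2)):

    opts = ("u0", "l0", "l1", "l1", "l0", "l1")
    local_axes = sum(o[0] == "l" for o in opts)    # 5 -> 2**5 == 32 == threads
    upcast_axes = sum(o[0] == "u" for o in opts)   # 1 -> 2**1 == 2  == elements_per_thread[2]
    reduce_axes = 3                                # int(math.log2(K)) with K == 8
    assert 8 * 8 == 2**(local_axes + upcast_axes)  # N x M == local x upcast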
+@dataclass(frozen=True)
+class Estimates:
+  # number of FLOPS used in the Kernel
+  ops:sint = 0
+  # bytes accessed in loads and stores
+  lds:sint = 0
+  # total bytes accessed, counting only once for bytes that are accessed multiple times
+  mem:sint = 0
+  def __add__(self, o:Estimates): return Estimates(self.ops + o.ops, self.lds + o.lds, self.mem + o.mem)
+  def simplify(self): return Estimates(ssimplify(self.ops), ssimplify(self.lds), ssimplify(self.mem))
+  @staticmethod
+  def from_uops(uops:list[UOp], ignore_indexing=False) -> Estimates:
+    flops: sint = 0
+    lds: sint = 0
+    mults: sint = 1
+    mult_stack: list[sint] = []
+    dont_count: set[UOp] = set()
+    if ignore_indexing:
+      for u in uops:
+        if u.op in {Ops.LOAD, Ops.STORE}:
+          dont_count = dont_count.union(u.src[0].toposort)
+          if len(u.src) > 2: dont_count = dont_count.union(u.src[2].toposort)
+        elif u.op is Ops.IF:
+          dont_count = dont_count.union(u.src[0].toposort)
+    for u in uops:
+      if u.op is Ops.RANGE:
+        mult_stack.append(mults)
+        mults *= (u.src[1] - u.src[0]).ssimplify()
+      elif u.op is Ops.ENDRANGE: mults = mult_stack.pop(-1)
+      elif u.op is Ops.SPECIAL: mults *= u.arg[1] # NOTE: we don't push to the mult_stack here, you can't end these
+      elif u.op is Ops.LOAD: lds += u.dtype.itemsize * mults
+      elif u.op is Ops.STORE: lds += u.src[1].dtype.itemsize * mults
+      elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
+      elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
+    return Estimates(flops, lds, lds) # TODO: properly track memory, lds is always a high estimate
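Estimates.from_uops takes over the flop/byte counting that the removed flops_mem import used to provide: it walks the linearized uops once, scaling every count by the trip counts of the enclosing RANGE and SPECIAL ops, adding itemsize bytes per load/store, and counting ALU ops per vector lane. A rough sketch of the arithmetic for one simple case (not real UOps, just the bookkeeping the loop performs):

    # one RANGE of 16 iterations containing a float4 ADD and a float4 LOAD
    mults = 16               # pushed when the RANGE is entered
    flops = mults * 1 * 4    # ADD counts once per lane, dtype.count == 4       -> 64
    lds = mults * 16         # the LOAD moves dtype.itemsize == 16 bytes / iter -> 256
    # from_uops would report Estimates(ops=64, lds=256, mem=256); mem just mirrors lds (see the TODO)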

 @dataclass
-class Program:
+class ProgramSpec:
   name:str
   src:str
-  dname:str
-  uops:Optional[List[UOp]]=None
+  device:str
+  ast:UOp # save the base ast (this is method cache key)
+  uops:Optional[list[UOp]]=None
+  applied_opts:Optional[list[Opt]]=None
   mem_estimate:sint=0 # TODO: get this from the load/store uops once min/max are good

   # filled in from uops (if we have uops)
-  global_size:Optional[List[int]]=None
-  local_size:Optional[List[int]]=None
-  vars:List[Variable]=field(default_factory=list)
-  globals:List[int]=field(default_factory=list)
-  outs:List[int]=field(default_factory=list)
+  global_size:Optional[list[int]]=None
+  local_size:Optional[list[int]]=None
+  vars:list[Variable]=field(default_factory=list)
+  globals:list[int]=field(default_factory=list)
+  outs:list[int]=field(default_factory=list)
+  ins:list[int]=field(default_factory=list)
   _ran_post_init:bool=False # NOTE: this is needed if you call replace on the Program

   def __post_init__(self):
@@ -44,7 +103,8 @@ class Program:
       for u in self.uops:
         if u.op is Ops.DEFINE_VAR: self.vars.append(u)
         if u.op is Ops.DEFINE_GLOBAL: self.globals.append(u.arg)
-        if u.op is Ops.STORE: self.outs.extend([x.arg for x in u.src[0].sparents if x.op is Ops.DEFINE_GLOBAL])
+        if u.op is Ops.STORE: self.outs.extend([x.arg for x in u.src[0].toposort if x.op is Ops.DEFINE_GLOBAL])
+        if u.op is Ops.LOAD: self.ins.extend([x.arg for x in u.src[0].toposort if x.op is Ops.DEFINE_GLOBAL])
         if u.op is Ops.SPECIAL:
           # NOTE: you have to set local_size and global_size to the base [1,1,1] outside this
           if u.arg[0][0] == 'i': self.local_size = None
@@ -53,19 +113,17 @@ class Program:
           special_size[int(u.arg[0][-1])] = u.arg[1]
       self.vars = sorted(self.vars, key=lambda v: v.arg)
       self.outs = sorted(dedup(self.outs))
+      self.ins = sorted(dedup(self.ins))
       self._ran_post_init = True

-  @property
-  def op_estimate(self) -> sint: return self._ops_lds[0]
-  @property
-  def lds_estimate(self) -> sint: return self._ops_lds[1]
   @functools.cached_property
-  def _ops_lds(self) -> Tuple[sint, sint]: return (0,0) if self.uops is None else flops_mem(self.uops, ignore_indexing=True)
+  def estimates(self) -> Estimates:
+    return replace(Estimates() if self.uops is None else Estimates.from_uops(self.uops, ignore_indexing=True), mem=self.mem_estimate)

   @functools.cached_property
   def function_name(self) -> str: return to_function_name(self.name)

-  def launch_dims(self, var_vals:Dict[Variable, int]):
+  def launch_dims(self, var_vals:dict[Variable, int]):
     global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else None
     local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else None
     return global_size, local_size
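Program is now ProgramSpec: dname becomes device, the base ast and applied_opts are stored on the spec, input buffer indices (ins) are tracked alongside outs, and the separate op_estimate/lds_estimate properties collapse into one cached estimates. A hedged sketch of reading the new fields from a spec produced by the codegen pipeline (the helper itself is hypothetical):

    def summarize(p) -> str:  # p: ProgramSpec
        e = p.estimates                        # Estimates(ops=..., lds=..., mem=...) replaces p.op_estimate / p.lds_estimate
        gsz, lsz = p.launch_dims(var_vals={})  # sym_infer over global_size / local_size
        return f"{p.function_name} on {p.device}: {e.ops} ops, {e.lds} bytes moved, launch {gsz}/{lsz}"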
@@ -78,12 +136,13 @@ class Renderer:
   has_local: bool = True
   has_shared: bool = True
   # NOTE: these two should be in (x,y,z) order to match the max_sizes argument in get_grouped_dims
-  global_max: Optional[Tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: UOps.SPECIAL int32 indexes right now
-  local_max: Optional[Tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: UOps.SPECIAL int32 indexes right now
+  global_max: Optional[tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: Ops.SPECIAL int32 indexes right now
+  local_max: Optional[tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: Ops.SPECIAL int32 indexes right now
   shared_max: int = 32768
-  tensor_cores: List[TensorCore] = []
-  extra_matcher: Any = None
-  code_for_op: Dict[Ops, Callable] = {}
+  tensor_cores: list[TensorCore] = []
+  pre_matcher: Optional[PatternMatcher] = None
+  extra_matcher: Optional[PatternMatcher] = None
+  code_for_op: dict[Ops, Callable] = {}

   def __reduce__(self): return self.__class__, ()
-  def render(self, name:str, uops:List[UOp]) -> str: raise NotImplementedError("needs a renderer")
+  def render(self, uops:list[UOp]) -> str: raise NotImplementedError("needs a renderer")
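Renderer.render loses its name argument; the kernel name now travels in the uop list itself, and the cstyle renderer below recovers it from an Ops.NAME uop (falling back to "test"). A hedged migration sketch for callers:

    # 0.10.0: src = renderer.render(to_function_name(name), uops)
    # 0.10.2: src = renderer.render(uops)   # an Ops.NAME uop in `uops` carries the kernel name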
@@ -1,11 +1,11 @@
-from __future__ import annotations
-from typing import Dict, List, Optional, Tuple, Union, DefaultDict, Literal, Callable, cast
-import os, math
+from typing import Optional, Union, Literal, Callable, cast
+import os, math, sys
 from collections import defaultdict, Counter
-from tinygrad.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, cast_float_to_bf16
+from tinygrad.ops import GroupOp, Ops, UOp, PatternMatcher, UPat
 from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
 from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
 from tinygrad.renderer import Renderer, TensorCore
+from tinygrad.codegen.devectorizer import no_vectorized_alu

 base_rewrite = PatternMatcher([
   (UPat(Ops.DEFINE_ACC, name="x"), lambda ctx,x: ctx[x.src[0]]),
@@ -18,10 +18,12 @@ base_rewrite = PatternMatcher([
    lambda ctx,x: f"for ({ctx.render_dtype(x.dtype)} {ctx[x]} = {ctx[x.src[0]]}; {ctx[x]} < {ctx[x.src[1]]}; {ctx[x]}++) {{"),
   (UPat(Ops.VECTORIZE, name="x"),
    lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(x.dtype))}" + \
-    (f"{{{','.join([ctx[y] for y in x.src])}}}" if ctx.device == "CLANG" else f"({','.join([ctx[y] for y in x.src])})")),
+    (f"{{{','.join([ctx[y] for y in x.src])}}}" if ctx.device in {'CPU', 'DSP'} else f"({','.join([ctx[y] for y in x.src])})")),
+  (UPat(Ops.CAST, name="x"), lambda ctx,x:
+    f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_dtype(x.dtype)})" if x.dtype.count > 1 and not isinstance(x.dtype, PtrDType) else None),
   (UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx[x.src[0]])})"),
   (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"(*(({ctx.buffer_prefix}{ctx.render_dtype(x.dtype)}*)&{ctx[x.src[0]]}))"),
-  (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.arg[1]}];"),
+  (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"),
   (UPat(Ops.BARRIER), lambda ctx: ctx.barrier),
   (UPat(Ops.NOOP, name="x"), lambda ctx,x: ctx[x.src[0]]),
   (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; /* {x.arg[1]} */"),
@@ -50,21 +52,27 @@ base_rewrite = PatternMatcher([
   (UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op](
     *([strip_parens(ctx[v]) if v.op == x.op and x.op in {Ops.ADD, Ops.MUL, Ops.XOR} else ctx[v] for v in x.src]), x.dtype)),
   (UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
-    (f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device == 'CLANG' else f".{'xyzwabcd'[x.arg[0]]}")),
+    (f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device in {'CPU', 'DSP'} else \
+    f".{'xyzwabcd'[x.arg[0]]}")),
+  # custom passes through with format
+  (UPat(Ops.CUSTOM, name="x"), lambda ctx,x: x.arg.format(*[ctx[y] for y in x.src])),
 ])

 extra_pm = PatternMatcher([
   # insert a NOOP before BITCAST to force it to be rendered. not needed on all backends?
   (UPat(Ops.BITCAST, name="x"),
    lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not Ops.NOOP else None),
-  # gate any stores that aren't gated with ifs
-  (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
-    lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
   # rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends)
   (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
+  # devectorize any bools
+  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.INDEX), dtype=dtypes.bool, name="alu"), no_vectorized_alu),
+  # CAST (from bool) can't be vectorized
+  (UPat(Ops.CAST, src=(UPat(dtype=dtypes.bool),), name="alu"), no_vectorized_alu),
+  # WHERE can't be vectorized
+  (UPat(Ops.WHERE, name="alu"), no_vectorized_alu),
 ])

-def uops_to_dtypes(uops:List[UOp]) -> List[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType)))
+def uops_to_dtypes(uops:list[UOp]) -> list[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType)))

 class CStyleLanguage(Renderer):
   kernel_prefix: str = ""
@@ -75,13 +83,13 @@ class CStyleLanguage(Renderer):
   smem_prefix_for_cast: bool = True
   arg_int_prefix: str = "const int"
   barrier: str = ""
-  code_for_workitem: Dict[Union[Literal["g"], Literal["l"], Literal["i"]], Callable] = {}
-  extra_args: List[str] = []
+  code_for_workitem: dict[Union[Literal["g"], Literal["l"], Literal["i"]], Callable] = {}
+  extra_args: list[str] = []
   float4: Optional[str] = None
-  type_map: Dict[DType, str] = {}
+  type_map: dict[DType, str] = {}
   infinity: str = "INFINITY"
   nan: str = "NAN"
-  code_for_op: Dict = {
+  code_for_op: dict = {
     Ops.SQRT: lambda x,dtype: f"sqrt({x})", Ops.RECIP: lambda x,dtype: f"(1/{x})", Ops.NEG: lambda x,dtype: f"-{x}",
     Ops.EXP2: lambda x,dtype: f"exp2({x})", Ops.LOG2: lambda x,dtype: f"log2({x})", Ops.SIN: lambda x,dtype: f"sin({x})",
     Ops.AND: lambda a,b,dtype: f"({a}&{b})", Ops.XOR: lambda a,b,dtype: f"({a}^{b})", Ops.OR: lambda a,b,dtype: f"({a}|{b})",
@@ -93,8 +101,8 @@ class CStyleLanguage(Renderer):
   string_rewrite = base_rewrite
   extra_matcher = extra_pm

-  def get_kernel_modifier(self, uops:List[UOp]) -> str: return ""
-  def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str:
+  def get_kernel_modifier(self, uops:list[UOp]) -> str: return ""
+  def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[DType,bool]]], uops:list[UOp], prefix=None) -> str:
     tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else "" # noqa: E501
     buftypes = [(name, self.render_dtype(dtype, mutable)+self.buffer_suffix if isinstance(dtype, (ImageDType, PtrDType)) else
                  self.arg_int_prefix if dtype == dtypes.int else None) for name,(dtype,mutable) in bufs]
@@ -105,24 +113,27 @@ class CStyleLanguage(Renderer):

   def render_cast(self, dt:DType, val: str) -> str: return f"({self.render_dtype(dt)})({val})"
   def render_dtype(self, dt:DType, mutable=True) -> str:
-    if isinstance(dt, ImageDType):
-      return f"{'write_only' if mutable else 'read_only'} image2d_t"
+    if isinstance(dt, ImageDType): return f"{'write_only' if mutable else 'read_only'} image2d_t"
     if isinstance(dt, PtrDType):
-      return (self.smem_prefix if dt.local and self.smem_prefix_for_cast else self.buffer_prefix) + \
-        self.render_dtype(dt.base) + ("*" if isinstance(dt, PtrDType) else "")
-    return self.type_map.get(scalar:=dt.scalar(), scalar.name) + (str(dt.count) if (dt.count) > 1 else "")
+      return (self.smem_prefix if dt.local and self.smem_prefix_for_cast else self.buffer_prefix) + self.render_dtype(dt.base) + "*"
+    if dt.count > 1: return self.type_map.get(scalar:=dt.scalar(), scalar.name).replace(" ", "_") + str(dt.count)
+    return self.type_map.get(scalar:=dt.scalar(), scalar.name)

   def __getitem__(self, key): return self.r[key] # hacky helper
-  def render(self, name:str, uops:List[UOp]) -> str:
-    r: Dict[UOp, str] = {}
+  def render(self, uops:list[UOp]) -> str:
+    r: dict[UOp, str] = {}
     self.r = r

     child_count = Counter(v for ru in uops for v in ru.src)
-    bufs: Dict[UOp, Tuple[str, Tuple[DType, bool]]] = {}
+    bufs: dict[UOp, tuple[str, tuple[DType, bool]]] = {}
     kernel = []
     depth = 1
-    c: DefaultDict[str, int] = defaultdict(int)
+    c: defaultdict[str, int] = defaultdict(int)
+    name = "test"
     for u in uops:
+      if u.op is Ops.NAME:
+        name = u.arg
+        continue
       if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
         r[u] = f"data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else u.arg[0]
         bufs[u] = (r[u], (u.dtype, False))
@@ -130,7 +141,7 @@ class CStyleLanguage(Renderer):

       # mark buffers that we store to writable
       if u.op is Ops.STORE:
-        for up in u.src[0].sparents:
+        for up in u.src[0].toposort:
           if up.op is Ops.DEFINE_GLOBAL: bufs[up] = (bufs[up][0], (bufs[up][1][0], True))

       # naming
@@ -147,8 +158,8 @@ class CStyleLanguage(Renderer):
       assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"

       if u.op in {Ops.ENDIF, Ops.ENDRANGE}: depth -= 1
-      if u.op in {Ops.CONST, Ops.GEP, Ops.INDEX} or (u.op in {Ops.VECTORIZE, *GroupOp.ALU, Ops.CAST, Ops.BITCAST}
-                                                     and child_count[u] == 1 and not getenv("EXPAND_SSA")):
+      if u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOM} or \
+        (u.op in {Ops.VECTORIZE, *GroupOp.ALU, Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA")):
         r[u] = l
       else:
         if u.op in {Ops.RANGE, Ops.ASSIGN, Ops.DEFINE_LOCAL} or u.dtype == dtypes.void:
@@ -164,25 +175,31 @@ class CStyleLanguage(Renderer):
     return self.render_kernel(name, kernel, list(bufs.values()), uops)

 class ClangRenderer(CStyleLanguage):
-  device = "CLANG"
+  device = "CPU"
   float4 = "(float4)"
   has_local = False
   global_max = None
   infinity = "__builtin_inff()"
   nan = '__builtin_nanf("")'
+  amx_tc = [TensorCore(dims=(sz,sz,1), threads=1, elements_per_thread=(sz,sz,sz*sz), dtype_in=dt, dtype_out=dt, swizzle=(None,((),(4,5,6,7,0,1,2,3))),
+            opts=("u0","u0","u0","u0","u1","u1","u1","u1")) for dt,sz in [(dt, 64 // dt.itemsize) for dt in [dtypes.float]]]
+  if AMX: tensor_cores = amx_tc

   # language options
   buffer_suffix = " restrict"
   type_map = {dtypes.bool:"_Bool", dtypes.half:"__fp16"}
   code_for_op = {**({k:v for k,v in CStyleLanguage.code_for_op.items() if k not in [Ops.EXP2, Ops.SIN, Ops.LOG2]}),
                  Ops.SQRT: lambda x,dtype: f"__builtin_sqrt({x})" if dtype == dtypes.float64 else f"__builtin_sqrtf({x})"}
+  # LLVM legalizes double => half cast on systems that don't support it natively (like x86 cpus without AVX512-FP16) into a compiler-rt libcall.
+  extra_matcher = PatternMatcher([(UPat.var("x", dtypes.float64).cast(dtypes.float16), lambda x: x.cast(dtypes.float32).cast(dtypes.float16))]) + \
+    CStyleLanguage.extra_matcher

-  if AMX:
-    tensor_cores = [TensorCore(dims=(sz,sz,1), threads=[], reduce_axes=[], upcast_axes=([(1,sz)],[(0,sz)],[(1,sz),(0,sz)]), dtype_in=dt, dtype_out=dt)
-      for dt, sz in [(dt, 64//dt.itemsize) for dt in [dtypes.float]]]
-
+  if sys.platform == 'win32':
+    kernel_prefix = "__attribute__((ms_abi)) "
   def render_vector_prefix(self, dt:DType) -> str:
-    return f"typedef {self.render_dtype(dt.scalar())} {self.render_dtype(dt)} __attribute__((aligned({(sz:=dt.itemsize)}),vector_size({sz})));"
+    # round (down) to power of two
+    alignment = 2**int(math.log2(dt.itemsize))
+    return f"typedef {self.render_dtype(dt.scalar())} {self.render_dtype(dt)} __attribute__((aligned({alignment}),vector_size({dt.itemsize})));"

   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
     prefix = [self.render_vector_prefix(dt) for dt in uops_to_dtypes(uops) if dt.count > 1]
@@ -192,7 +209,10 @@ class ClangRenderer(CStyleLanguage):
         '#define AMX_SET(imm5) __asm("nop\\nnop\\nnop\\n.word (0x201000+(%0<<5)+%1)" : : "i"(17), "i"(imm5) : "memory")',
         '#define AMX(op, gpr, btf) __asm(".word (0x201000+(%0 << 5)+0%1-((0%1>>4)*6))" : : "i"(op), "r"((unsigned long long)(gpr)+(btf)) : "memory")',
       ]
-      prefix += [f"""{(out := self.render_dtype(dtype_in.vec(N*N)))} __{name}({self.render_dtype(dtype_in.vec(N))} data1, {self.render_dtype(dtype_in.vec(M))} data2, {out} data0){{
+      # 'static' in C roughly means that function symbol isn't exported. LLVM puts those symbols at the end of object file which allows Clang JIT
+      # to just jump at the start of a shellcode whithout having to deal with symbols or trampolines at all. This is better than having to inline
+      # wmma function every time it is called or wasting complexity on a symbol parsing and a memory page on trampoline.
+      prefix += [f"""static {(out := self.render_dtype(dtype_in.vec(N*N)))} __{name}({self.render_dtype(dtype_in.vec(N))} data1, {self.render_dtype(dtype_in.vec(M))} data2, {out} data0){{
   AMX_SET(0);\n for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(4, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}
   AMX(0, (int *)(&data2), 0ull<<62); AMX(1, (int *)(&data1), 0ull<<62); AMX(12, 0, 0ull);
   for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(5, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}\n AMX_SET(1);\n return data0;\n}}"""] # noqa: E501
@@ -209,7 +229,8 @@ class OpenCLRenderer(CStyleLanguage):
   barrier = "barrier(CLK_LOCAL_MEM_FENCE);"
   float4 = "(float4)"
   code_for_workitem = {"g": lambda x: f"get_group_id({x})", "l": lambda x: f"get_local_id({x})", "i": lambda x: f"get_global_id({x})"}
-  type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong", dtypes.bfloat16: "ushort" }
+  type_map = { dtypes.int8: "char", dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong",
+               dtypes.bfloat16: "ushort" }

   string_rewrite = PatternMatcher([
     (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_{ctx.render_dtype(x.dtype)}({ctx[x.src[0]]})"),
@@ -223,17 +244,17 @@ class OpenCLRenderer(CStyleLanguage):
   ]) + base_rewrite

   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
-    if any(uop.dtype == dtypes.half for uop in uops): prefix = (["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"] + (prefix or []))
+    if any(uop.dtype.base == dtypes.half for uop in uops): prefix = (["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"] + (prefix or []))
     return super().render_kernel(function_name, kernel, bufs, uops, prefix)

 class IntelRenderer(OpenCLRenderer):
   device, suffix, kernel_prefix = "GPU", "INTEL", "__attribute__((intel_reqd_sub_group_size(8)))\n" + "__kernel "
-  tensor_cores = [TensorCore(dims=(8,8,16),threads=[(0,8)],dtype_in=di,dtype_out=do,reduce_axes=[(0,16)],upcast_axes=([(0,16)],[(0,16)],[(1,8)]),
-    st1_pattern=(((1,0),),((1,2),(1,1),(0,0))),expanded_shape=(8,2,8)) for di,do in [(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)]]
+  tensor_cores = [TensorCore(dims=(8,8,16), threads=8, elements_per_thread=(16,16,8), dtype_in=dtypes.half, dtype_out=dtypes.float,
+    opts=("l0","l0","l0","u1","u1","u1"), swizzle=(((4,5,6),(0,1,2,3,7,8,9)), ((0,1,2),(7,8,9,3,4,5,6))))]

   string_rewrite = PatternMatcher([
-    (UPat(Ops.CAST, dtype=dtypes.bfloat16, src=(UPat.var('x', dtype=dtypes.float))), lambda ctx,x: f"intel_convert_bfloat16_as_ushort({ctx[x[0]]})"),
-    (UPat(Ops.CAST, dtype=dtypes.float, src=(UPat.var('x', dtype=dtypes.bfloat16))), lambda ctx,x: f"intel_convert_as_bfloat16_float({ctx[x[0]]})"),
+    (UPat(Ops.CAST, dtype=dtypes.bfloat16, src=(UPat.var('x', dtype=dtypes.float))), lambda ctx,x: f"intel_convert_bfloat16_as_ushort({ctx[x]})"),
+    (UPat(Ops.CAST, dtype=dtypes.float, src=(UPat.var('x', dtype=dtypes.bfloat16))), lambda ctx,x: f"intel_convert_as_bfloat16_float({ctx[x]})"),
   ]) + OpenCLRenderer.string_rewrite

   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
@@ -247,9 +268,9 @@ class IntelRenderer(OpenCLRenderer):
 class MetalRenderer(CStyleLanguage):
   device = "METAL"
   shared_max = 32768
-  tensor_cores = [TensorCore(dims=(8,8,8),threads=[(0,2),(1,4),(0,2),(1,2)],expanded_shape=(2,2,2,2),upcast_axes=([(1,2)],[(1,2)],[(1,2)]),
-    st1_pattern=(((1,1),(0,1),(1,0),(0,3)),((0,0),(0,2),(1,3),(1,2))),st2_pattern=(((0,0),(1,1),(1,2),(0,2),(1,0)),((0,1),(0,3),(1,3))),
-    dtype_in=di,dtype_out=do,reduce_axes=[(0,8)]) for di,do in [(dtypes.float,dtypes.float),(dtypes.half,dtypes.float),(dtypes.half,dtypes.half)]]
+  tensor_cores = [TensorCore(dims=(8,8,8), threads=32, elements_per_thread=(2,2,2), dtype_in=di, dtype_out=do, opts=("u0","l0","l1","l1","l0","l1"),
+    swizzle=(((6,1,2,7,4),(8,0,3,5)), ((0,5,6,3,7),(1,2,4,8)))) for di,do in [(dtypes.float,dtypes.float),(dtypes.half,dtypes.float),
+    (dtypes.half,dtypes.half),(dtypes.bfloat16,dtypes.float),(dtypes.bfloat16,dtypes.bfloat16)]]
   def __init__(self): self.tensor_cores = MetalRenderer.tensor_cores if hasattr(os, 'uname') and os.uname().machine == "arm64" else []

   # language options
@@ -289,18 +310,27 @@ class MetalRenderer(CStyleLanguage):
     return super().render_kernel(function_name, kernel, bufs, uops, prefix)

 _nms = "xyzwabcdefghijkl"
+cuda_tc_opts = ("u0","l0","l0","l1","l1","l1","u1") # shared by all shapes with M=16 N=8

 class CUDARenderer(CStyleLanguage):
   device = "CUDA"
   global_max = (2147483647, 65535, 65535)
   local_max = (1024, 1024, 64)
   shared_max = 49152
-  # https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float
-  tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(1,2)], dtype_in=di, dtype_out=do, expanded_shape=(2,2,2,2,2,2),
-    st1_pattern=(((1,1),(1,0),(0,2),(0,3),(0,4)),((1,3),(1,5),(1,2),(0,0),(0,1),(1,4))),
-    st2_pattern=(((1,1),(1,0),(1,4),(0,0),(0,1)),((0,4),(0,2),(1,5),(0,3),(1,3),(1,2))), reduce_axes=[(0,8),(1,2)],
-    upcast_axes=([(0,8)],[(2,2),(3,2)],[(3,2),(2,2)])) for di, do in ([(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)])]
-  def __init__(self, arch:str): self.tensor_cores, self.arch = CUDARenderer.tensor_cores if int(arch[3:]) >= 80 else [], arch
+  # https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-multiply-accumulate-instructions
+  tc_81616 = [TensorCore(dims=(8,16,16), threads=32, elements_per_thread=(8,4,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
+    swizzle=(((6,7,2,3,4),(0,1,9,5,10,8)), ((6,7,9,0,1),(2,3,4,10,5,8)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.bfloat16,dtypes.float),
+    (dtypes.half,dtypes.half)]]
+  tc_8168_f16 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
+    swizzle=(((6,7,2,3,4),(0,1,8,5,9)), ((6,7,8,0,1),(2,3,4,9,5)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.half,dtypes.half)]]
+  tc_8168_tf32 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=dtypes.float, dtype_out=dtypes.float, opts=cuda_tc_opts,
+    swizzle=(((5,6,2,3,4),(0,1,8,9,7)), ((5,6,8,0,1),(2,3,4,9,7))))]
+
+  tc_sm80 = tc_81616 + tc_8168_f16
+  if getenv("ALLOW_TF32", 0): tc_sm80 += tc_8168_tf32
+  tc_sm75 = tc_8168_f16
+  def __init__(self, arch:str):
+    self.tensor_cores, self.arch = CUDARenderer.tc_sm80 if int(arch[3:]) >= 80 else CUDARenderer.tc_sm75 if int(arch[3:]) >= 75 else [], arch
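Tensor core selection is now tiered by compute capability instead of all-or-nothing: sm_80 and newer get tc_sm80 (tc_81616 plus tc_8168_f16, with tf32 only when ALLOW_TF32=1), sm_75 gets only tc_8168_f16, and older parts get none. The arch string is parsed by slicing off the "sm_" prefix:

    arch = "sm_86"        # example arch string
    cc = int(arch[3:])    # -> 86
    # cc >= 80 -> tc_sm80, 75 <= cc < 80 -> tc_sm75, otherwise no tensor cores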
   def __reduce__(self): return self.__class__, (self.arch,)

   # language options
@@ -333,7 +363,8 @@ class CUDARenderer(CStyleLanguage):
     if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("#include <cuda_bf16.h>")
     prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count in (4,8) and dt.scalar() in {dtypes.half, dtypes.bfloat16}]

-    dt_map = { dtypes.half: "f16", dtypes.bfloat16: "bf16" }
+    dt_map_in = { dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16" }
+    dt_map_out = { dtypes.float: "f32", dtypes.half: "f16" }
     for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):
       upcast_sizes = [prod(size for _, size in upcast) for upcast in upcast_axes]
       wmma_dtypes = [self.render_dtype(dtype.vec(size)) for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)]
@@ -342,27 +373,34 @@ class CUDARenderer(CStyleLanguage):

       # mma operands => {c}, {a}, {b}, {c}
       prefix.append(f"""__device__ {wmma_dtypes[2]} __{name}({wmma_dtypes[0]} a, {wmma_dtypes[1]} b, {wmma_dtypes[2]} c){{
-  int *a_pk = (int *)(&a), *b_pk = (int *)(&b);\n asm("mma.sync.aligned.m{M}n{N}k{K}.row.col.f32.{dt_map[dtype_in]}.{dt_map[dtype_in]}.f32"
+  int *a_pk = (int *)(&a), *b_pk = (int *)(&b), *c_pk = (int *)(&c);
+  asm("mma.sync.aligned.m{M}n{N}k{K}.row.col.{dt_map_out[dtype_out]}.{dt_map_in[dtype_in]}.{dt_map_in[dtype_in]}.{dt_map_out[dtype_out]}"
      "{{{", ".join(operands[:n_operands[2]])}}}, {{{", ".join(operands[n_operands[2]:n_operands[2]+n_operands[0]])}}},"
      "{{{", ".join(operands[-n_operands[1]:])}}}, {{{", ".join(operands[:n_operands[2]])}}};"
-     : {", ".join([f'"+f"(c.{_nms[i]})' for i in range(n_operands[2])])}
+     : {", ".join([f'"+r"(c_pk[{i}])' for i in range(n_operands[2])])}
      : {", ".join([f'"r"(a_pk[{i}])' for i in range(n_operands[0])])}, {", ".join([f'"r"(b_pk[{i}])' for i in range(n_operands[1])])});
   return c;\n}}""")

     return super().render_kernel(function_name, kernel, bufs, uops, prefix=prefix)

-  def get_kernel_modifier(self, uops:List[UOp]) -> str:
+  def get_kernel_modifier(self, uops:list[UOp]) -> str:
     maxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == "l")
     # https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html
     return f"__launch_bounds__({maxThreadsPerBlock}) "

+def cast_float_to_bf16(x: UOp) -> UOp:
+  assert x.dtype == dtypes.float, "cast float -> bf16 must start with float"
+  x = x.bitcast(dtypes.uint)
+  x = (-x & 0x7f800000).where(x + ((x >> 16) & 1) + 0x7fff, (x & 0xffff).where((x | 0x10000), x))
+  return (x >> 16).cast(dtypes.ushort).bitcast(dtypes.bfloat16)
+
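cast_float_to_bf16 moves into this module (the tinygrad.ops import above no longer provides it) and builds a round-to-nearest-even float-to-bfloat16 conversion out of uops: ordinary values get the bias ((x >> 16) & 1) + 0x7fff added before the truncating shift, while the other where branch keeps NaN payloads non-zero so they survive the >>16. The tie-breaking arithmetic on plain integers:

    x = 0x3F808000                   # a float exactly halfway between bf16 0x3F80 and 0x3F81
    x += ((x >> 16) & 1) + 0x7fff    # kept lsb is 0 (even), so the bias is only 0x7fff
    assert x >> 16 == 0x3F80         # the tie rounds toward the even neighbour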
 class AMDRenderer(CStyleLanguage):
   device = "AMD"
   shared_max = 65536
   # https://gpuopen.com/learn/wmma_on_rdna3/
-  tensor_cores = [TensorCore(dims=(16,16,16), threads=[(0,8),(0,2),(1,2)], dtype_in=di, dtype_out=do, reduce_axes=[(0,16)], opts_seq=("LC","UP"),
-    upcast_axes = ([(0,16)],[(0,16)],[(1,8)]), st1_pattern=(((1,2),(0,2),(1,1),(0,1)),((1,0),(0,0))), expanded_shape=(16,2,4))
-    for (di, do) in [(dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]]
+  tensor_cores = [TensorCore(dims=(16,16,16), threads=32, elements_per_thread=(16,16,8), dtype_in=di, dtype_out=do,
+    opts=("l0","l0","l0","l0","l1","u1","u1","u1"), swizzle=(((4,9,10,11,0),(1,2,3,5,6,7,8)), ((0,1,2,3,4),(9,10,11,5,6,7,8))))
+    for di,do in [(dtypes.half,dtypes.float),(dtypes.half,dtypes.half)]]

   # language options
   ockl = [(f"__ockl_get_{name}", "unsigned int", "size_t", "const") for name in ["local_id", "group_id", "local_size"]]
@@ -397,8 +435,7 @@ class AMDRenderer(CStyleLanguage):
     (UPat(Ops.CAST, dtypes.bfloat16, UPat.var("x")),lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None),
     # bfloat16 casting
     (UPat.cvar('x', dtypes.bfloat16), lambda x: cast_float_to_bf16(UOp.const(dtypes.float, x.arg))),
-    (UPat(Ops.CAST, dtype=dtypes.float, src=UPat.var("x", dtype=dtypes.bfloat16)),
-      lambda x: (x.bitcast(dtypes.ushort).cast(dtypes.uint)<<16).bitcast(dtypes.float)),
+    (UPat(Ops.CAST, dtypes.float, UPat.var("x", dtypes.bfloat16)), lambda x: (x.bitcast(dtypes.ushort).cast(dtypes.uint)<<16).bitcast(dtypes.float)),
     (UPat(Ops.CAST, dtype=dtypes.bfloat16, src=UPat.var("x", dtype=dtypes.float)), cast_float_to_bf16)]) + extra_pm

   def render_vector_prefix(self, dtype:DType) -> str:
@@ -410,7 +447,7 @@ class AMDRenderer(CStyleLanguage):
     prefix = ["#define INFINITY (__builtin_inff())","#define NAN (__builtin_nanf(\"\"))","typedef long unsigned int size_t;","#define half _Float16"]

     used_dtypes = uops_to_dtypes(uops)
-    if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("struct hip_bfloat16 { unsigned short data; };")
+    if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("typedef unsigned short hip_bfloat16;")
     prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count > 1]

     for arg in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
@@ -421,42 +458,12 @@ class AMDRenderer(CStyleLanguage):
     for (int n = 0; n < 8; n++) { d[n] = c_frag[n*2]; } return d;\n}""")
     return super().render_kernel(function_name, kernel, bufs, uops, prefix)

-  def get_kernel_modifier(self, uops:List[UOp]) -> str:
+  def get_kernel_modifier(self, uops:list[UOp]) -> str:
     requiredMaxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == "l")
     # https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
     # NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
     return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"

-class DSPRenderer(ClangRenderer):
-  device = "DSP"
-  supports_float4 = False
-  buffer_suffix = " restrict __attribute__((align_value(128)))"
-  kernel_prefix = "__attribute__((noinline)) "
-  type_map = { **ClangRenderer.type_map, dtypes.uint64: "unsigned long long", dtypes.int64: "long long" }
-  code_for_op = {**ClangRenderer.code_for_op, Ops.SIN: lambda x,dtype: f"__builtin_sin({x})",
-    Ops.LOG2: lambda x,dtype: f"__builtin_log2l({x})" if dtype == dtypes.float64 else f"__builtin_log2f({x})",
-    Ops.EXP2: lambda x,dtype: f"__builtin_exp2l({x})" if dtype == dtypes.float64 else f"__builtin_exp2f({x})"}
-
-  def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str:
-    ret = super().render_kernel(function_name, kernel, bufs, uops, prefix)
-    msrc = ['''struct dcvs_v2_req { int type; int _pad; _Bool dcvs_enable; char dcvs_option; _Bool set_latency; int latency; _Bool set_dcvs_params;
-      short _pad2; char target_corner; char min_corner; char max_corner; int _pad3[3]; };''', 'int HAP_power_set(void*, void*);',
-      'typedef union { struct { void *pv; unsigned int len; } buf; struct { int fd; unsigned int offset; } dma; } remote_arg;',
-      'void* HAP_mmap(void *addr, int len, int prot, int flags, int fd, long offset);', 'int HAP_munmap(void *addr, int len);',
-      'unsigned long long HAP_perf_get_time_us(void);', 'int entry(unsigned long long handle, unsigned int sc, remote_arg* pra) {',
-      'struct dcvs_v2_req req = {.type=7, .dcvs_enable=0, .set_latency=1, .latency=100, .set_dcvs_params=1, .target_corner = 6 /* TURBO */};',
-      'HAP_power_set((void*)handle, (void*)&req);']
-    msrc += ['if ((sc>>24) != 2) return 0;']
-    msrc += [f'int sz_or_val_{i} = ((int*)pra[0].buf.pv)[{i}];' for i,b in enumerate(bufs)]
-    msrc += [f'int off{i} = ((int*)pra[1].buf.pv)[{i}];' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
-    msrc += [f'void *buf_{i} = HAP_mmap(0,sz_or_val_{i},3,0,pra[{i+3}].dma.fd,0)+off{i};' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
-    msrc += ["unsigned long long start = HAP_perf_get_time_us();"]
-    msrc += [f"{function_name}({', '.join([(f'buf_{i}' if isinstance(b[1][0], PtrDType) else f'sz_or_val_{i}') for i,b in enumerate(bufs)])});"]
-    msrc += ["*(unsigned long long *)(pra[2].buf.pv) = HAP_perf_get_time_us() - start;"]
-    msrc += [f'HAP_munmap(buf_{i}, sz_or_val_{i});' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
-    msrc += ["return 0; }"]
-    return ret + '\n' + '\n'.join(msrc)
-
 class NVRenderer(CUDARenderer): device = "NV"
 class HIPRenderer(AMDRenderer): device = "HIP"
 class QCOMRenderer(OpenCLRenderer): device = "QCOM"