tinygrad 0.10.0__py3-none-any.whl → 0.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. tinygrad/codegen/kernel.py +114 -172
  2. tinygrad/codegen/linearize.py +211 -81
  3. tinygrad/codegen/lowerer.py +30 -35
  4. tinygrad/codegen/{uopgraph.py → rewriter.py} +69 -59
  5. tinygrad/codegen/transcendental.py +12 -13
  6. tinygrad/device.py +170 -47
  7. tinygrad/dtype.py +28 -26
  8. tinygrad/engine/jit.py +80 -63
  9. tinygrad/engine/memory.py +4 -5
  10. tinygrad/engine/multi.py +162 -0
  11. tinygrad/engine/realize.py +58 -107
  12. tinygrad/engine/schedule.py +381 -314
  13. tinygrad/engine/search.py +40 -44
  14. tinygrad/gradient.py +70 -0
  15. tinygrad/helpers.py +77 -58
  16. tinygrad/nn/__init__.py +30 -32
  17. tinygrad/nn/datasets.py +1 -2
  18. tinygrad/nn/optim.py +22 -26
  19. tinygrad/nn/state.py +89 -64
  20. tinygrad/ops.py +562 -446
  21. tinygrad/renderer/__init__.py +79 -36
  22. tinygrad/renderer/cstyle.py +70 -84
  23. tinygrad/renderer/llvmir.py +32 -20
  24. tinygrad/renderer/ptx.py +79 -99
  25. tinygrad/renderer/wgsl.py +87 -0
  26. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  27. tinygrad/runtime/autogen/comgr.py +2 -0
  28. tinygrad/runtime/autogen/kfd.py +4 -3
  29. tinygrad/runtime/autogen/kgsl.py +1 -1
  30. tinygrad/runtime/autogen/libpciaccess.py +2023 -0
  31. tinygrad/runtime/autogen/llvm.py +11379 -0
  32. tinygrad/runtime/autogen/vfio.py +891 -0
  33. tinygrad/runtime/graph/cuda.py +8 -9
  34. tinygrad/runtime/graph/hcq.py +84 -79
  35. tinygrad/runtime/graph/metal.py +19 -21
  36. tinygrad/runtime/ops_amd.py +488 -327
  37. tinygrad/runtime/ops_clang.py +15 -28
  38. tinygrad/runtime/ops_cloud.py +34 -34
  39. tinygrad/runtime/ops_cuda.py +30 -27
  40. tinygrad/runtime/ops_disk.py +62 -63
  41. tinygrad/runtime/ops_dsp.py +129 -38
  42. tinygrad/runtime/ops_gpu.py +30 -30
  43. tinygrad/runtime/ops_hip.py +29 -31
  44. tinygrad/runtime/ops_llvm.py +45 -40
  45. tinygrad/runtime/ops_metal.py +93 -73
  46. tinygrad/runtime/ops_npy.py +2 -2
  47. tinygrad/runtime/ops_nv.py +232 -270
  48. tinygrad/runtime/ops_python.py +51 -46
  49. tinygrad/runtime/ops_qcom.py +129 -157
  50. tinygrad/runtime/ops_webgpu.py +63 -0
  51. tinygrad/runtime/support/allocator.py +94 -0
  52. tinygrad/runtime/support/am/__init__.py +0 -0
  53. tinygrad/runtime/support/am/amdev.py +384 -0
  54. tinygrad/runtime/support/am/ip.py +463 -0
  55. tinygrad/runtime/support/compiler_cuda.py +4 -2
  56. tinygrad/runtime/support/elf.py +26 -4
  57. tinygrad/runtime/support/hcq.py +254 -324
  58. tinygrad/runtime/support/llvm.py +32 -0
  59. tinygrad/shape/shapetracker.py +84 -53
  60. tinygrad/shape/view.py +103 -138
  61. tinygrad/spec.py +154 -0
  62. tinygrad/tensor.py +744 -496
  63. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/METADATA +32 -21
  64. tinygrad-0.10.1.dist-info/RECORD +86 -0
  65. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/WHEEL +1 -1
  66. tinygrad/engine/lazy.py +0 -228
  67. tinygrad/function.py +0 -212
  68. tinygrad/multi.py +0 -177
  69. tinygrad/runtime/graph/clang.py +0 -39
  70. tinygrad-0.10.0.dist-info/RECORD +0 -77
  71. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/LICENSE +0 -0
  72. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/top_level.txt +0 -0
tinygrad/renderer/__init__.py
@@ -1,41 +1,85 @@
- from typing import Optional, List, Tuple, Dict, Callable, Any
- import functools
- from dataclasses import dataclass, field
+ from __future__ import annotations
+ from typing import Optional, Callable
+ import functools, math
+ from dataclasses import dataclass, field, replace
  from tinygrad.helpers import to_function_name, dedup, prod
- from tinygrad.ops import Ops, UOp, flops_mem, sym_infer, sint, Variable
+ from tinygrad.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, GroupOp, PatternMatcher
  from tinygrad.dtype import DType

  @dataclass(frozen=True)
  class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N)
- dims: Tuple[int,int,int] # N, M, K
+ dims: tuple[int,int,int] # N, M, K
+ threads: int # number of threads that construct the warp
+ elements_per_thread: tuple[int, int, int] # elements per-thread to load/store from A/B/C
  dtype_in: DType # dtype for A and B
  dtype_out: DType # dtype for C and D
- threads: List[Tuple[int,int]] # list of (TC dim,amt) that construct the warp thread structure
- reduce_axes: List[Tuple[int,int]] # list of (TC dim,amt) that constructs the shape of the reduce dim
- @property
- def early_upcast_axes(self) -> List[Tuple[int,int]]: # list of (TC dim,amt) that upcasts the threads remainders of dims [0,1]
- return [(d,self.dims[d]//sz) for d,sz in [(dim,prod(sz for d,sz in self.threads if d==dim)) for dim in range(2)] if self.dims[d]>sz]
- upcast_axes: Tuple[List[Tuple[int,int]], List[Tuple[int,int]], List[Tuple[int,int]]] # list of (TC dim,amt) that upcast A, B and C
- st1_pattern: Optional[Tuple[Tuple[Tuple[int,int], ...], Tuple[Tuple[int,int], ...]]] = None # pattern to fix shapetracker for A
- st2_pattern: Optional[Tuple[Tuple[Tuple[int,int], ...], Tuple[Tuple[int,int], ...]]] = None # pattern to fix shapetracker for B
- expanded_shape: Optional[Tuple[int, ...]] = None
- opts_seq: Tuple[str,str] = ("UP","LC") # upcast input, local the thread pattern
+ opts: tuple[str, ...] # ordered tuple of "ux" or "lx" specifing kernel opts to perform. "ux" upcasts dim x and "lx" localizes dim x
+ swizzle: tuple[Optional[tuple[tuple[int, ...], tuple[int, ...]]], Optional[tuple[tuple[int, ...], tuple[int, ...]]]] = (None, None)
+ def get_reduce_axes(self): return [(i, 2) for i in range(int(math.log2(self.dims[2])))]
+ def get_upcast_axes(self): return [opt for opt in self.opts if opt[0] == "u"]
+ def get_local_axes(self): return [opt for opt in self.opts if opt[0] == "l"]
  def __str__(self): return "_".join(["WMMA"] + list(map(str, self.dims)) + [self.dtype_in.name, self.dtype_out.name])
+ def __post_init__(self):
+ local_axes, upcast_axes, reduce_axes = len(self.get_local_axes()), len(self.get_upcast_axes()), len(self.get_reduce_axes())
+ assert self.dims[0] * self.dims[1] == 2**(local_axes + upcast_axes), (
+ f"N({self.dims[0]}) x M({self.dims[1]}) != local({2**local_axes}) x upcast({2**upcast_axes}) with opts({self.opts})")
+ assert 2**local_axes == self.threads, f"{self.threads} threads construct the warp but found {2**local_axes} in {self.opts}"
+ assert 2**upcast_axes == self.elements_per_thread[2], (
+ f"{self.elements_per_thread[2]} elements from C are processed per thread but found {2**upcast_axes} in {self.opts}")
+ assert all(len(perm[0]) == local_axes and len(perm[1]) == reduce_axes + upcast_axes for perm in self.swizzle if perm), (
+ f"swizzle perm should be of len (({local_axes})({reduce_axes + upcast_axes}))")
+
+ @dataclass(frozen=True)
+ class Estimates:
+ # number of FLOPS used in the Kernel
+ ops:sint = 0
+ # bytes accessed in loads and stores
+ lds:sint = 0
+ # total bytes accessed, counting only once for bytes that are accessed multiple times
+ mem:sint = 0
+ def __add__(self, o:Estimates): return Estimates(self.ops + o.ops, self.lds + o.lds, self.mem + o.mem)
+ def simplify(self): return Estimates(ssimplify(self.ops), ssimplify(self.lds), ssimplify(self.mem))
+ @staticmethod
+ def from_uops(uops:list[UOp], ignore_indexing=False) -> Estimates:
+ flops: sint = 0
+ lds: sint = 0
+ mults: sint = 1
+ mult_stack: list[sint] = []
+ dont_count: set[UOp] = set()
+ if ignore_indexing:
+ for u in uops:
+ if u.op in {Ops.LOAD, Ops.STORE}:
+ dont_count = dont_count.union(u.src[0].toposort)
+ if len(u.src) > 2: dont_count = dont_count.union(u.src[2].toposort)
+ elif u.op is Ops.IF:
+ dont_count = dont_count.union(u.src[0].toposort)
+ for u in uops:
+ if u.op is Ops.RANGE:
+ mult_stack.append(mults)
+ mults *= (u.src[1] - u.src[0]).ssimplify()
+ elif u.op is Ops.ENDRANGE: mults = mult_stack.pop(-1)
+ elif u.op is Ops.SPECIAL: mults *= u.arg[1] # NOTE: we don't push to the mult_stack here, you can't end these
+ elif u.op is Ops.LOAD: lds += u.dtype.itemsize * mults
+ elif u.op is Ops.STORE: lds += u.src[1].dtype.itemsize * mults
+ elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
+ elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
+ return Estimates(flops, lds, lds) # TODO: properly track memory, lds is always a high estimate

  @dataclass
- class Program:
+ class ProgramSpec:
  name:str
  src:str
- dname:str
- uops:Optional[List[UOp]]=None
+ device:str
+ uops:Optional[list[UOp]]=None
  mem_estimate:sint=0 # TODO: get this from the load/store uops once min/max are good

  # filled in from uops (if we have uops)
- global_size:Optional[List[int]]=None
- local_size:Optional[List[int]]=None
- vars:List[Variable]=field(default_factory=list)
- globals:List[int]=field(default_factory=list)
- outs:List[int]=field(default_factory=list)
+ global_size:Optional[list[int]]=None
+ local_size:Optional[list[int]]=None
+ vars:list[Variable]=field(default_factory=list)
+ globals:list[int]=field(default_factory=list)
+ outs:list[int]=field(default_factory=list)
+ ins:list[int]=field(default_factory=list)
  _ran_post_init:bool=False # NOTE: this is needed if you call replace on the Program

  def __post_init__(self):
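The old per-axis `threads`/`reduce_axes`/`upcast_axes` lists are replaced by the flat `opts` encoding plus a `swizzle` pair, and `__post_init__` now enforces the invariants between them. A minimal sketch (assuming tinygrad 0.10.1 is installed), using the METAL 8x8x8 entry that appears later in this diff as the example:

```python
from tinygrad.dtype import dtypes
from tinygrad.renderer import TensorCore

# the METAL 8x8x8 tensor core from later in this diff; __post_init__ checks the invariants below
tc = TensorCore(dims=(8,8,8), threads=32, elements_per_thread=(2,2,2),
                dtype_in=dtypes.half, dtype_out=dtypes.float,
                opts=("u0","l0","l1","l1","l0","l1"),
                swizzle=(((6,1,2,7,4),(8,0,3,5)), ((0,5,6,3,7),(1,2,4,8))))
assert 2**len(tc.get_local_axes()) == tc.threads                   # five "l" opts -> a 32-thread warp
assert 2**len(tc.get_upcast_axes()) == tc.elements_per_thread[2]   # one "u" opt -> 2 elements of C per thread
assert tc.get_reduce_axes() == [(0, 2), (1, 2), (2, 2)]            # K=8 -> three binary reduce axes
```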
@@ -44,7 +88,8 @@ class Program:
  for u in self.uops:
  if u.op is Ops.DEFINE_VAR: self.vars.append(u)
  if u.op is Ops.DEFINE_GLOBAL: self.globals.append(u.arg)
- if u.op is Ops.STORE: self.outs.extend([x.arg for x in u.src[0].sparents if x.op is Ops.DEFINE_GLOBAL])
+ if u.op is Ops.STORE: self.outs.extend([x.arg for x in u.src[0].toposort if x.op is Ops.DEFINE_GLOBAL])
+ if u.op is Ops.LOAD: self.ins.extend([x.arg for x in u.src[0].toposort if x.op is Ops.DEFINE_GLOBAL])
  if u.op is Ops.SPECIAL:
  # NOTE: you have to set local_size and global_size to the base [1,1,1] outside this
  if u.arg[0][0] == 'i': self.local_size = None
@@ -53,19 +98,17 @@ class Program:
  special_size[int(u.arg[0][-1])] = u.arg[1]
  self.vars = sorted(self.vars, key=lambda v: v.arg)
  self.outs = sorted(dedup(self.outs))
+ self.ins = sorted(dedup(self.ins))
  self._ran_post_init = True

- @property
- def op_estimate(self) -> sint: return self._ops_lds[0]
- @property
- def lds_estimate(self) -> sint: return self._ops_lds[1]
  @functools.cached_property
- def _ops_lds(self) -> Tuple[sint, sint]: return (0,0) if self.uops is None else flops_mem(self.uops, ignore_indexing=True)
+ def estimates(self) -> Estimates:
+ return replace(Estimates() if self.uops is None else Estimates.from_uops(self.uops, ignore_indexing=True), mem=self.mem_estimate)

  @functools.cached_property
  def function_name(self) -> str: return to_function_name(self.name)

- def launch_dims(self, var_vals:Dict[Variable, int]):
+ def launch_dims(self, var_vals:dict[Variable, int]):
  global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else None
  local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else None
  return global_size, local_size
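`Program` is renamed to `ProgramSpec` (`dname` becomes `device`, and loaded buffers are now tracked in `ins`), while the separate `op_estimate`/`lds_estimate` properties collapse into one cached `Estimates` value. A minimal usage sketch under the same assumption that 0.10.1 is installed; the field values here are made up for illustration:

```python
from tinygrad.renderer import ProgramSpec

# with no uops, estimates falls back to an empty Estimates, with mem seeded from mem_estimate
p = ProgramSpec(name="add", src="/* rendered source */", device="CPU", mem_estimate=1024)
print(p.estimates.ops, p.estimates.lds, p.estimates.mem)  # 0 0 1024 (was p.op_estimate / p.lds_estimate)
print(p.launch_dims({}))                                  # (None, None) until global_size/local_size are filled in
```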
@@ -78,12 +121,12 @@ class Renderer:
  has_local: bool = True
  has_shared: bool = True
  # NOTE: these two should be in (x,y,z) order to match the max_sizes argument in get_grouped_dims
- global_max: Optional[Tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: UOps.SPECIAL int32 indexes right now
- local_max: Optional[Tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: UOps.SPECIAL int32 indexes right now
+ global_max: Optional[tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: UOps.SPECIAL int32 indexes right now
+ local_max: Optional[tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: UOps.SPECIAL int32 indexes right now
  shared_max: int = 32768
- tensor_cores: List[TensorCore] = []
- extra_matcher: Any = None
- code_for_op: Dict[Ops, Callable] = {}
+ tensor_cores: list[TensorCore] = []
+ extra_matcher: Optional[PatternMatcher] = None
+ code_for_op: dict[Ops, Callable] = {}

  def __reduce__(self): return self.__class__, ()
- def render(self, name:str, uops:List[UOp]) -> str: raise NotImplementedError("needs a renderer")
+ def render(self, name:str, uops:list[UOp]) -> str: raise NotImplementedError("needs a renderer")
tinygrad/renderer/cstyle.py
@@ -1,8 +1,7 @@
- from __future__ import annotations
- from typing import Dict, List, Optional, Tuple, Union, DefaultDict, Literal, Callable, cast
- import os, math
+ from typing import Optional, Union, Literal, Callable, cast
+ import os, math, sys
  from collections import defaultdict, Counter
- from tinygrad.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, cast_float_to_bf16
+ from tinygrad.ops import GroupOp, Ops, UOp, PatternMatcher, UPat
  from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
  from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
  from tinygrad.renderer import Renderer, TensorCore
@@ -21,7 +20,7 @@ base_rewrite = PatternMatcher([
  (f"{{{','.join([ctx[y] for y in x.src])}}}" if ctx.device == "CLANG" else f"({','.join([ctx[y] for y in x.src])})")),
  (UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx[x.src[0]])})"),
  (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"(*(({ctx.buffer_prefix}{ctx.render_dtype(x.dtype)}*)&{ctx[x.src[0]]}))"),
- (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.arg[1]}];"),
+ (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"),
  (UPat(Ops.BARRIER), lambda ctx: ctx.barrier),
  (UPat(Ops.NOOP, name="x"), lambda ctx,x: ctx[x.src[0]]),
  (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; /* {x.arg[1]} */"),
@@ -57,14 +56,11 @@ extra_pm = PatternMatcher([
  # insert a NOOP before BITCAST to force it to be rendered. not needed on all backends?
  (UPat(Ops.BITCAST, name="x"),
  lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not Ops.NOOP else None),
- # gate any stores that aren't gated with ifs
- (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
- lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
  # rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends)
  (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
  ])

- def uops_to_dtypes(uops:List[UOp]) -> List[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType)))
+ def uops_to_dtypes(uops:list[UOp]) -> list[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType)))

  class CStyleLanguage(Renderer):
  kernel_prefix: str = ""
@@ -75,13 +71,13 @@ class CStyleLanguage(Renderer):
  smem_prefix_for_cast: bool = True
  arg_int_prefix: str = "const int"
  barrier: str = ""
- code_for_workitem: Dict[Union[Literal["g"], Literal["l"], Literal["i"]], Callable] = {}
- extra_args: List[str] = []
+ code_for_workitem: dict[Union[Literal["g"], Literal["l"], Literal["i"]], Callable] = {}
+ extra_args: list[str] = []
  float4: Optional[str] = None
- type_map: Dict[DType, str] = {}
+ type_map: dict[DType, str] = {}
  infinity: str = "INFINITY"
  nan: str = "NAN"
- code_for_op: Dict = {
+ code_for_op: dict = {
  Ops.SQRT: lambda x,dtype: f"sqrt({x})", Ops.RECIP: lambda x,dtype: f"(1/{x})", Ops.NEG: lambda x,dtype: f"-{x}",
  Ops.EXP2: lambda x,dtype: f"exp2({x})", Ops.LOG2: lambda x,dtype: f"log2({x})", Ops.SIN: lambda x,dtype: f"sin({x})",
  Ops.AND: lambda a,b,dtype: f"({a}&{b})", Ops.XOR: lambda a,b,dtype: f"({a}^{b})", Ops.OR: lambda a,b,dtype: f"({a}|{b})",
@@ -93,8 +89,8 @@ class CStyleLanguage(Renderer):
  string_rewrite = base_rewrite
  extra_matcher = extra_pm

- def get_kernel_modifier(self, uops:List[UOp]) -> str: return ""
- def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str:
+ def get_kernel_modifier(self, uops:list[UOp]) -> str: return ""
+ def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[DType,bool]]], uops:list[UOp], prefix=None) -> str:
  tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else "" # noqa: E501
  buftypes = [(name, self.render_dtype(dtype, mutable)+self.buffer_suffix if isinstance(dtype, (ImageDType, PtrDType)) else
  self.arg_int_prefix if dtype == dtypes.int else None) for name,(dtype,mutable) in bufs]
@@ -105,23 +101,21 @@ class CStyleLanguage(Renderer):

  def render_cast(self, dt:DType, val: str) -> str: return f"({self.render_dtype(dt)})({val})"
  def render_dtype(self, dt:DType, mutable=True) -> str:
- if isinstance(dt, ImageDType):
- return f"{'write_only' if mutable else 'read_only'} image2d_t"
+ if isinstance(dt, ImageDType): return f"{'write_only' if mutable else 'read_only'} image2d_t"
  if isinstance(dt, PtrDType):
- return (self.smem_prefix if dt.local and self.smem_prefix_for_cast else self.buffer_prefix) + \
- self.render_dtype(dt.base) + ("*" if isinstance(dt, PtrDType) else "")
+ return (self.smem_prefix if dt.local and self.smem_prefix_for_cast else self.buffer_prefix) + self.render_dtype(dt.base) + "*"
  return self.type_map.get(scalar:=dt.scalar(), scalar.name) + (str(dt.count) if (dt.count) > 1 else "")

  def __getitem__(self, key): return self.r[key] # hacky helper
- def render(self, name:str, uops:List[UOp]) -> str:
- r: Dict[UOp, str] = {}
+ def render(self, name:str, uops:list[UOp]) -> str:
+ r: dict[UOp, str] = {}
  self.r = r

  child_count = Counter(v for ru in uops for v in ru.src)
- bufs: Dict[UOp, Tuple[str, Tuple[DType, bool]]] = {}
+ bufs: dict[UOp, tuple[str, tuple[DType, bool]]] = {}
  kernel = []
  depth = 1
- c: DefaultDict[str, int] = defaultdict(int)
+ c: defaultdict[str, int] = defaultdict(int)
  for u in uops:
  if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
  r[u] = f"data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else u.arg[0]
@@ -130,7 +124,7 @@ class CStyleLanguage(Renderer):

  # mark buffers that we store to writable
  if u.op is Ops.STORE:
- for up in u.src[0].sparents:
+ for up in u.src[0].toposort:
  if up.op is Ops.DEFINE_GLOBAL: bufs[up] = (bufs[up][0], (bufs[up][1][0], True))

  # naming
@@ -147,8 +141,8 @@ class CStyleLanguage(Renderer):
  assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"

  if u.op in {Ops.ENDIF, Ops.ENDRANGE}: depth -= 1
- if u.op in {Ops.CONST, Ops.GEP, Ops.INDEX} or (u.op in {Ops.VECTORIZE, *GroupOp.ALU, Ops.CAST, Ops.BITCAST}
- and child_count[u] == 1 and not getenv("EXPAND_SSA")):
+ if u.op in {Ops.CONST, Ops.GEP, Ops.INDEX} or \
+ (u.op in {Ops.VECTORIZE, *GroupOp.ALU, Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA")):
  r[u] = l
  else:
  if u.op in {Ops.RANGE, Ops.ASSIGN, Ops.DEFINE_LOCAL} or u.dtype == dtypes.void:
@@ -176,11 +170,16 @@ class ClangRenderer(CStyleLanguage):
  type_map = {dtypes.bool:"_Bool", dtypes.half:"__fp16"}
  code_for_op = {**({k:v for k,v in CStyleLanguage.code_for_op.items() if k not in [Ops.EXP2, Ops.SIN, Ops.LOG2]}),
  Ops.SQRT: lambda x,dtype: f"__builtin_sqrt({x})" if dtype == dtypes.float64 else f"__builtin_sqrtf({x})"}
+ # LLVM legalizes double => half cast on systems that don't support it natively (like x86 cpus without AVX512-FP16) into a compiler-rt libcall.
+ extra_matcher = PatternMatcher([(UPat.var("x", dtypes.float64).cast(dtypes.float16), lambda x: x.cast(dtypes.float32).cast(dtypes.float16))]) + \
+ CStyleLanguage.extra_matcher

  if AMX:
- tensor_cores = [TensorCore(dims=(sz,sz,1), threads=[], reduce_axes=[], upcast_axes=([(1,sz)],[(0,sz)],[(1,sz),(0,sz)]), dtype_in=dt, dtype_out=dt)
- for dt, sz in [(dt, 64//dt.itemsize) for dt in [dtypes.float]]]
-
+ tensor_cores = [TensorCore(dims=(sz,sz,1), threads=1, elements_per_thread=(sz,sz,sz*sz), dtype_in=dt, dtype_out=dt,
+ swizzle=(None, ((),(4,5,6,7,0,1,2,3))), opts=("u0","u0","u0","u0","u1","u1","u1","u1"))
+ for dt,sz in [(dt, 64 // dt.itemsize) for dt in [dtypes.float]]]
+ if sys.platform == 'win32':
+ kernel_prefix = "__attribute__((ms_abi)) "
  def render_vector_prefix(self, dt:DType) -> str:
  return f"typedef {self.render_dtype(dt.scalar())} {self.render_dtype(dt)} __attribute__((aligned({(sz:=dt.itemsize)}),vector_size({sz})));"

@@ -192,7 +191,10 @@ class ClangRenderer(CStyleLanguage):
  '#define AMX_SET(imm5) __asm("nop\\nnop\\nnop\\n.word (0x201000+(%0<<5)+%1)" : : "i"(17), "i"(imm5) : "memory")',
  '#define AMX(op, gpr, btf) __asm(".word (0x201000+(%0 << 5)+0%1-((0%1>>4)*6))" : : "i"(op), "r"((unsigned long long)(gpr)+(btf)) : "memory")',
  ]
- prefix += [f"""{(out := self.render_dtype(dtype_in.vec(N*N)))} __{name}({self.render_dtype(dtype_in.vec(N))} data1, {self.render_dtype(dtype_in.vec(M))} data2, {out} data0){{
+ # 'static' in C roughly means that function symbol isn't exported. LLVM puts those symbols at the end of object file which allows Clang JIT
+ # to just jump at the start of a shellcode whithout having to deal with symbols or trampolines at all. This is better than having to inline
+ # wmma function every time it is called or wasting complexity on a symbol parsing and a memory page on trampoline.
+ prefix += [f"""static {(out := self.render_dtype(dtype_in.vec(N*N)))} __{name}({self.render_dtype(dtype_in.vec(N))} data1, {self.render_dtype(dtype_in.vec(M))} data2, {out} data0){{
  AMX_SET(0);\n for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(4, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}
  AMX(0, (int *)(&data2), 0ull<<62); AMX(1, (int *)(&data1), 0ull<<62); AMX(12, 0, 0ull);
  for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(5, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}\n AMX_SET(1);\n return data0;\n}}"""] # noqa: E501
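The new ClangRenderer `extra_matcher` above splits a float64 to float16 cast into float64 to float32 to float16, so the backend never needs a compiler-rt truncation helper (typically `__truncdfhf2`). A plain-Python sketch of that two-step path, illustrative only and using `struct` on concrete values rather than UOps; the helper name is made up:

```python
import struct

def f64_to_f16_via_f32(x: float) -> bytes:
    # first truncate the Python double to float32, then pack that float32 as an IEEE half
    f32 = struct.unpack("<f", struct.pack("<f", x))[0]
    return struct.pack("<e", f32)

print(f64_to_f16_via_f32(1.5).hex())  # "003e": 0x3e00 little-endian, i.e. 1.5 in IEEE half
```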
@@ -209,7 +211,8 @@ class OpenCLRenderer(CStyleLanguage):
  barrier = "barrier(CLK_LOCAL_MEM_FENCE);"
  float4 = "(float4)"
  code_for_workitem = {"g": lambda x: f"get_group_id({x})", "l": lambda x: f"get_local_id({x})", "i": lambda x: f"get_global_id({x})"}
- type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong", dtypes.bfloat16: "ushort" }
+ type_map = { dtypes.int8: "char", dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong",
+ dtypes.bfloat16: "ushort" }

  string_rewrite = PatternMatcher([
  (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_{ctx.render_dtype(x.dtype)}({ctx[x.src[0]]})"),
@@ -223,17 +226,17 @@ class OpenCLRenderer(CStyleLanguage):
  ]) + base_rewrite

  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
- if any(uop.dtype == dtypes.half for uop in uops): prefix = (["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"] + (prefix or []))
+ if any(uop.dtype.base == dtypes.half for uop in uops): prefix = (["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"] + (prefix or []))
  return super().render_kernel(function_name, kernel, bufs, uops, prefix)

  class IntelRenderer(OpenCLRenderer):
  device, suffix, kernel_prefix = "GPU", "INTEL", "__attribute__((intel_reqd_sub_group_size(8)))\n" + "__kernel "
- tensor_cores = [TensorCore(dims=(8,8,16),threads=[(0,8)],dtype_in=di,dtype_out=do,reduce_axes=[(0,16)],upcast_axes=([(0,16)],[(0,16)],[(1,8)]),
- st1_pattern=(((1,0),),((1,2),(1,1),(0,0))),expanded_shape=(8,2,8)) for di,do in [(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)]]
+ tensor_cores = [TensorCore(dims=(8,8,16), threads=8, elements_per_thread=(16,16,8), dtype_in=dtypes.half, dtype_out=dtypes.float,
+ opts=("l0","l0","l0","u1","u1","u1"), swizzle=(((4,5,6),(0,1,2,3,7,8,9)), ((0,1,2),(7,8,9,3,4,5,6))))]

  string_rewrite = PatternMatcher([
- (UPat(Ops.CAST, dtype=dtypes.bfloat16, src=(UPat.var('x', dtype=dtypes.float))), lambda ctx,x: f"intel_convert_bfloat16_as_ushort({ctx[x[0]]})"),
- (UPat(Ops.CAST, dtype=dtypes.float, src=(UPat.var('x', dtype=dtypes.bfloat16))), lambda ctx,x: f"intel_convert_as_bfloat16_float({ctx[x[0]]})"),
+ (UPat(Ops.CAST, dtype=dtypes.bfloat16, src=(UPat.var('x', dtype=dtypes.float))), lambda ctx,x: f"intel_convert_bfloat16_as_ushort({ctx[x]})"),
+ (UPat(Ops.CAST, dtype=dtypes.float, src=(UPat.var('x', dtype=dtypes.bfloat16))), lambda ctx,x: f"intel_convert_as_bfloat16_float({ctx[x]})"),
  ]) + OpenCLRenderer.string_rewrite

  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
@@ -247,9 +250,9 @@ class IntelRenderer(OpenCLRenderer):
  class MetalRenderer(CStyleLanguage):
  device = "METAL"
  shared_max = 32768
- tensor_cores = [TensorCore(dims=(8,8,8),threads=[(0,2),(1,4),(0,2),(1,2)],expanded_shape=(2,2,2,2),upcast_axes=([(1,2)],[(1,2)],[(1,2)]),
- st1_pattern=(((1,1),(0,1),(1,0),(0,3)),((0,0),(0,2),(1,3),(1,2))),st2_pattern=(((0,0),(1,1),(1,2),(0,2),(1,0)),((0,1),(0,3),(1,3))),
- dtype_in=di,dtype_out=do,reduce_axes=[(0,8)]) for di,do in [(dtypes.float,dtypes.float),(dtypes.half,dtypes.float),(dtypes.half,dtypes.half)]]
+ tensor_cores = [TensorCore(dims=(8,8,8), threads=32, elements_per_thread=(2,2,2), dtype_in=di, dtype_out=do, opts=("u0","l0","l1","l1","l0","l1"),
+ swizzle=(((6,1,2,7,4),(8,0,3,5)), ((0,5,6,3,7),(1,2,4,8)))) for di,do in [(dtypes.float,dtypes.float),(dtypes.half,dtypes.float),
+ (dtypes.half,dtypes.half),(dtypes.bfloat16,dtypes.float),(dtypes.bfloat16,dtypes.bfloat16)]]
  def __init__(self): self.tensor_cores = MetalRenderer.tensor_cores if hasattr(os, 'uname') and os.uname().machine == "arm64" else []

  # language options
@@ -289,18 +292,26 @@ class MetalRenderer(CStyleLanguage):
  return super().render_kernel(function_name, kernel, bufs, uops, prefix)

  _nms = "xyzwabcdefghijkl"
+ cuda_tc_opts = ("u0","l0","l0","l1","l1","l1","u1") # shared by all shapes with M=16 N=8

  class CUDARenderer(CStyleLanguage):
  device = "CUDA"
  global_max = (2147483647, 65535, 65535)
  local_max = (1024, 1024, 64)
  shared_max = 49152
- # https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float
- tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(1,2)], dtype_in=di, dtype_out=do, expanded_shape=(2,2,2,2,2,2),
- st1_pattern=(((1,1),(1,0),(0,2),(0,3),(0,4)),((1,3),(1,5),(1,2),(0,0),(0,1),(1,4))),
- st2_pattern=(((1,1),(1,0),(1,4),(0,0),(0,1)),((0,4),(0,2),(1,5),(0,3),(1,3),(1,2))), reduce_axes=[(0,8),(1,2)],
- upcast_axes=([(0,8)],[(2,2),(3,2)],[(3,2),(2,2)])) for di, do in ([(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)])]
- def __init__(self, arch:str): self.tensor_cores, self.arch = CUDARenderer.tensor_cores if int(arch[3:]) >= 80 else [], arch
+ # https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-multiply-accumulate-instructions
+ tc_81616 = [TensorCore(dims=(8,16,16), threads=32, elements_per_thread=(8,4,4), dtype_in=di,dtype_out=do, opts=cuda_tc_opts,
+ swizzle=(((6,7,2,3,4),(0,1,9,5,10,8)), ((6,7,9,0,1),(2,3,4,10,5,8)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.bfloat16,dtypes.float)]]
+ tc_8168_f16 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=dtypes.half, dtype_out=dtypes.float, opts=cuda_tc_opts,
+ swizzle=(((6,7,2,3,4),(0,1,8,5,9)), ((6,7,8,0,1),(2,3,4,9,5))))]
+ tc_8168_tf32 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=dtypes.float, dtype_out=dtypes.float, opts=cuda_tc_opts,
+ swizzle=(((5,6,2,3,4),(0,1,8,9,7)), ((5,6,8,0,1),(2,3,4,9,7))))]
+
+ tc_sm80 = tc_81616 + tc_8168_f16
+ if getenv("ALLOW_TF32", 0): tc_sm80 += tc_8168_tf32
+ tc_sm75 = tc_8168_f16
+ def __init__(self, arch:str):
+ self.tensor_cores, self.arch = CUDARenderer.tc_sm80 if int(arch[3:]) >= 80 else CUDARenderer.tc_sm75 if int(arch[3:]) >= 75 else [], arch
  def __reduce__(self): return self.__class__, (self.arch,)

  # language options
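With the split into `tc_sm75`/`tc_sm80` sets, tensor-core selection is now driven by the SM version parsed from the arch string, and the TF32 cores are opt-in via `ALLOW_TF32`. A minimal sketch (assuming 0.10.1 is installed; the arch strings are example inputs and no GPU is needed just to instantiate the renderer):

```python
from tinygrad.renderer.cstyle import CUDARenderer

print(len(CUDARenderer("sm_89").tensor_cores))  # tc_sm80: m16n8k16 half/bf16 + m16n8k8 half (+ tf32 with ALLOW_TF32=1)
print(len(CUDARenderer("sm_75").tensor_cores))  # tc_sm75: only the m16n8k8 half entry
print(len(CUDARenderer("sm_70").tensor_cores))  # 0: no tensor cores below sm_75
```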
@@ -333,7 +344,7 @@ class CUDARenderer(CStyleLanguage):
  if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("#include <cuda_bf16.h>")
  prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count in (4,8) and dt.scalar() in {dtypes.half, dtypes.bfloat16}]

- dt_map = { dtypes.half: "f16", dtypes.bfloat16: "bf16" }
+ dt_map = { dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16" }
  for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):
  upcast_sizes = [prod(size for _, size in upcast) for upcast in upcast_axes]
  wmma_dtypes = [self.render_dtype(dtype.vec(size)) for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)]
@@ -351,18 +362,24 @@ class CUDARenderer(CStyleLanguage):

  return super().render_kernel(function_name, kernel, bufs, uops, prefix=prefix)

- def get_kernel_modifier(self, uops:List[UOp]) -> str:
+ def get_kernel_modifier(self, uops:list[UOp]) -> str:
  maxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == "l")
  # https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html
  return f"__launch_bounds__({maxThreadsPerBlock}) "

+ def cast_float_to_bf16(x: UOp) -> UOp:
+ assert x.dtype == dtypes.float, "cast float -> bf16 must start with float"
+ x = x.bitcast(dtypes.uint)
+ x = (-x & 0x7f800000).where(x + ((x >> 16) & 1) + 0x7fff, (x & 0xffff).where((x | 0x10000), x))
+ return (x >> 16).cast(dtypes.ushort).bitcast(dtypes.bfloat16)
+
  class AMDRenderer(CStyleLanguage):
  device = "AMD"
  shared_max = 65536
  # https://gpuopen.com/learn/wmma_on_rdna3/
- tensor_cores = [TensorCore(dims=(16,16,16), threads=[(0,8),(0,2),(1,2)], dtype_in=di, dtype_out=do, reduce_axes=[(0,16)], opts_seq=("LC","UP"),
- upcast_axes = ([(0,16)],[(0,16)],[(1,8)]), st1_pattern=(((1,2),(0,2),(1,1),(0,1)),((1,0),(0,0))), expanded_shape=(16,2,4))
- for (di, do) in [(dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]]
+ tensor_cores = [TensorCore(dims=(16,16,16), threads=32, elements_per_thread=(16,16,8), dtype_in=di, dtype_out=do,
+ opts=("l0","l0","l0","l0","l1","u1","u1","u1"), swizzle=(((4,9,10,11,0),(1,2,3,5,6,7,8)), ((0,1,2,3,4),(9,10,11,5,6,7,8))))
+ for di,do in [(dtypes.half,dtypes.float),(dtypes.half,dtypes.half)]]

  # language options
  ockl = [(f"__ockl_get_{name}", "unsigned int", "size_t", "const") for name in ["local_id", "group_id", "local_size"]]
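`cast_float_to_bf16` moves here (it was previously imported from `tinygrad.ops`) and builds the conversion out of UOps: round to nearest, ties to even on the low 16 bits for ordinary values, with NaNs kept as NaNs. A plain-Python sketch of the same bit manipulation, illustrative only; the helper name is made up:

```python
import struct

def f32_to_bf16_bits(f: float) -> int:
    x = struct.unpack("<I", struct.pack("<f", f))[0]
    if (-x) & 0x7f800000:                               # everything except +-0 and NaN: round to nearest, ties to even
        x = (x + ((x >> 16) & 1) + 0x7fff) & 0xffffffff
    elif x & 0xffff:                                    # NaN: force a mantissa bit so it stays a NaN after truncation
        x |= 0x10000
    return (x >> 16) & 0xffff

assert f32_to_bf16_bits(1.0) == 0x3f80
assert f32_to_bf16_bits(float("inf")) == 0x7f80
nan_bits = f32_to_bf16_bits(float("nan"))
assert (nan_bits & 0x7f80) == 0x7f80 and (nan_bits & 0x7f) != 0   # still a NaN
```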
@@ -397,8 +414,7 @@ class AMDRenderer(CStyleLanguage):
  (UPat(Ops.CAST, dtypes.bfloat16, UPat.var("x")),lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None),
  # bfloat16 casting
  (UPat.cvar('x', dtypes.bfloat16), lambda x: cast_float_to_bf16(UOp.const(dtypes.float, x.arg))),
- (UPat(Ops.CAST, dtype=dtypes.float, src=UPat.var("x", dtype=dtypes.bfloat16)),
- lambda x: (x.bitcast(dtypes.ushort).cast(dtypes.uint)<<16).bitcast(dtypes.float)),
+ (UPat(Ops.CAST, dtypes.float, UPat.var("x", dtypes.bfloat16)), lambda x: (x.bitcast(dtypes.ushort).cast(dtypes.uint)<<16).bitcast(dtypes.float)),
  (UPat(Ops.CAST, dtype=dtypes.bfloat16, src=UPat.var("x", dtype=dtypes.float)), cast_float_to_bf16)]) + extra_pm

  def render_vector_prefix(self, dtype:DType) -> str:
@@ -410,7 +426,7 @@ class AMDRenderer(CStyleLanguage):
  prefix = ["#define INFINITY (__builtin_inff())","#define NAN (__builtin_nanf(\"\"))","typedef long unsigned int size_t;","#define half _Float16"]

  used_dtypes = uops_to_dtypes(uops)
- if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("struct hip_bfloat16 { unsigned short data; };")
+ if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("typedef unsigned short hip_bfloat16;")
  prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count > 1]

  for arg in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
@@ -421,42 +437,12 @@ class AMDRenderer(CStyleLanguage):
  for (int n = 0; n < 8; n++) { d[n] = c_frag[n*2]; } return d;\n}""")
  return super().render_kernel(function_name, kernel, bufs, uops, prefix)

- def get_kernel_modifier(self, uops:List[UOp]) -> str:
+ def get_kernel_modifier(self, uops:list[UOp]) -> str:
  requiredMaxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == "l")
  # https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
  # NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
  return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"

- class DSPRenderer(ClangRenderer):
- device = "DSP"
- supports_float4 = False
- buffer_suffix = " restrict __attribute__((align_value(128)))"
- kernel_prefix = "__attribute__((noinline)) "
- type_map = { **ClangRenderer.type_map, dtypes.uint64: "unsigned long long", dtypes.int64: "long long" }
- code_for_op = {**ClangRenderer.code_for_op, Ops.SIN: lambda x,dtype: f"__builtin_sin({x})",
- Ops.LOG2: lambda x,dtype: f"__builtin_log2l({x})" if dtype == dtypes.float64 else f"__builtin_log2f({x})",
- Ops.EXP2: lambda x,dtype: f"__builtin_exp2l({x})" if dtype == dtypes.float64 else f"__builtin_exp2f({x})"}
-
- def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str:
- ret = super().render_kernel(function_name, kernel, bufs, uops, prefix)
- msrc = ['''struct dcvs_v2_req { int type; int _pad; _Bool dcvs_enable; char dcvs_option; _Bool set_latency; int latency; _Bool set_dcvs_params;
- short _pad2; char target_corner; char min_corner; char max_corner; int _pad3[3]; };''', 'int HAP_power_set(void*, void*);',
- 'typedef union { struct { void *pv; unsigned int len; } buf; struct { int fd; unsigned int offset; } dma; } remote_arg;',
- 'void* HAP_mmap(void *addr, int len, int prot, int flags, int fd, long offset);', 'int HAP_munmap(void *addr, int len);',
- 'unsigned long long HAP_perf_get_time_us(void);', 'int entry(unsigned long long handle, unsigned int sc, remote_arg* pra) {',
- 'struct dcvs_v2_req req = {.type=7, .dcvs_enable=0, .set_latency=1, .latency=100, .set_dcvs_params=1, .target_corner = 6 /* TURBO */};',
- 'HAP_power_set((void*)handle, (void*)&req);']
- msrc += ['if ((sc>>24) != 2) return 0;']
- msrc += [f'int sz_or_val_{i} = ((int*)pra[0].buf.pv)[{i}];' for i,b in enumerate(bufs)]
- msrc += [f'int off{i} = ((int*)pra[1].buf.pv)[{i}];' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
- msrc += [f'void *buf_{i} = HAP_mmap(0,sz_or_val_{i},3,0,pra[{i+3}].dma.fd,0)+off{i};' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
- msrc += ["unsigned long long start = HAP_perf_get_time_us();"]
- msrc += [f"{function_name}({', '.join([(f'buf_{i}' if isinstance(b[1][0], PtrDType) else f'sz_or_val_{i}') for i,b in enumerate(bufs)])});"]
- msrc += ["*(unsigned long long *)(pra[2].buf.pv) = HAP_perf_get_time_us() - start;"]
- msrc += [f'HAP_munmap(buf_{i}, sz_or_val_{i});' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
- msrc += ["return 0; }"]
- return ret + '\n' + '\n'.join(msrc)
-
  class NVRenderer(CUDARenderer): device = "NV"
  class HIPRenderer(AMDRenderer): device = "HIP"
  class QCOMRenderer(OpenCLRenderer): device = "QCOM"
tinygrad/renderer/llvmir.py
@@ -1,4 +1,4 @@
- from typing import List, Dict, cast
+ from typing import cast
  import math, struct
  from tinygrad.renderer import Renderer
  from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, GroupOp
@@ -60,19 +60,23 @@ llvm_rewrite = PatternMatcher([

  # range
  (UPat(Ops.RANGE, name="x"), lambda ctx,x:
- f" br label %loop_entry_{x.arg[0]}\nloop_entry_{x.arg[0]}:\n"
- f" br label %loop_body_{x.arg[0]}\nloop_body_{x.arg[0]}:\n"
- f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x.src[0]]}, %loop_entry_{x.arg[0]}], [{ctx[x]}phi, %loop_latch_{x.arg[0]}]"),
+ f" br label %loop_entry_{x.arg}\nloop_entry_{x.arg}:\n"
+ f" br label %loop_body_{x.arg}\nloop_body_{x.arg}:\n"
+ f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x.src[0]]}, %loop_entry_{x.arg}], [{ctx[x]}phi, %loop_latch_{x.arg}]"),
  (UPat(Ops.ENDRANGE, name="x"), lambda ctx,x:
- f" br label %loop_latch_{x.src[0].arg[0]}\nloop_latch_{x.src[0].arg[0]}:\n"
+ f" br label %loop_latch_{x.src[0].arg}\nloop_latch_{x.src[0].arg}:\n"
  f" {ctx[x.src[0]]}phi = add i32 {ctx[x.src[0]]}, 1\n {ctx[x]} = icmp ult i32 {ctx[x.src[0]]}phi, {ctx[x.src[0].src[1]]}\n"
- f" br i1 {ctx[x]}, label %loop_body_{x.src[0].arg[0]}, label %loop_exit_{x.src[0].arg[0]}\nloop_exit_{x.src[0].arg[0]}:"),
+ f" br i1 {ctx[x]}, label %loop_body_{x.src[0].arg}, label %loop_exit_{x.src[0].arg}\nloop_exit_{x.src[0].arg}:"),

  # if
  (UPat(Ops.IF, name="x"), lambda ctx,x: f" br i1 {ctx[x.src[0]]}, label %ifbody_{ctx[x][1:]}, label %ifskip_{ctx[x][1:]}\nifbody_{ctx[x][1:]}:"),
  (UPat(Ops.ENDIF, name="x"), lambda ctx,x: f" br label %ifskip_{ctx[x.src[0]][1:]}\nifskip_{ctx[x.src[0]][1:]}:"),
  ])

+ def llvm_bf16_cast(buf:UOp, idx:UOp, root:UOp):
+ u16_buf = buf.replace(dtype=dtypes.ushort.ptr(size=cast(PtrDType,buf.dtype).size))
+ return UOp.load(UOp.index(u16_buf, idx), dtype=dtypes.ushort).cast(dtypes.uint).mul(1<<16).bitcast(dtypes.float32).cast(root.dtype)
+
  class LLVMRenderer(Renderer):
  device = "LLVM"
  supports_float4 = False
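`llvm_bf16_cast` upcasts a bfloat16 load by reloading the buffer as u16 and moving those bits into the high half of a float32, since bfloat16 is simply a truncated float32. A plain-Python sketch of that bit trick, illustrative only; the helper name is made up:

```python
import struct

def bf16_bits_to_f32(b: int) -> float:
    # place the 16 bf16 bits in the high half of a float32 bit pattern
    return struct.unpack("<f", struct.pack("<I", b << 16))[0]

assert bf16_bits_to_f32(0x3f80) == 1.0
assert bf16_bits_to_f32(0xc000) == -2.0
```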
@@ -85,23 +89,24 @@ class LLVMRenderer(Renderer):
  (UPat(Ops.RECIP, name="x"), lambda x: UOp(Ops.FDIV, x.dtype, (x.const_like(1), x.src[0]))),
  # rewrite cast to bool to CMPNE 0
  (UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x: x.src[0] != x.src[0].const_like(0)),
- # *** also in cstyle ***
- # gate any stores that aren't gated with ifs
- (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
- lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
  # rewrite MAX to CMPLT + WHERE
  (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
+ # rewrite bf16 CAST(LOAD) to CAST(BITCAST)
+ (UPat(Ops.CAST, name="root", src=(UPat.load(UPat.index(UPat.var("buf"), UPat.var("idx")), dtype=dtypes.bfloat16),)), llvm_bf16_cast),
  ])

- def render(self, name: str, uops: List[UOp]) -> str:
- r: Dict[UOp, str] = {}
- args: List[str] = []
- kernel: List[str] = []
- end_lines: Dict[str, None] = {}
+ def __init__(self, abi:str|None=None):
+ self.abi = abi
+
+ def render(self, name: str, uops: list[UOp]) -> str:
+ r: dict[UOp, str] = {}
+ args: list[str] = []
+ kernel: list[str] = []
+ end_lines: dict[str, None] = {}
  vc = -1

  # prealloc all assigns
- acc_to_assign: Dict[UOp, UOp] = {}
+ acc_to_assign: dict[UOp, UOp] = {}
  for u in uops:
  if u.op is Ops.ASSIGN:
  vc += 1
@@ -133,10 +138,17 @@ class LLVMRenderer(Renderer):
  # generate the phi nodes for the assigns
  if u.op is Ops.RANGE:
  for x in acc_to_assign:
- if u in x.src: # if this range is relevent for this acc
+ if u in x.src: # if this range is relevant for this acc
  vc += 1
- kernel.append(f" %acc{vc} = phi {ldt(x.dtype)}" f"[{r[x]}, %loop_entry_{u.arg[0]}], [{r[acc_to_assign[x]]}, %loop_latch_{u.arg[0]}]")
+ kernel.append(f" %acc{vc} = phi {ldt(x.dtype)}" f"[{r[x]}, %loop_entry_{u.arg}], [{r[acc_to_assign[x]]}, %loop_latch_{u.arg}]")
  r[x] = f"%acc{vc}"

- # output the function
- return f"define void @{name}({','.join(args)}) {{\n" + '\n'.join(kernel) + "\n ret void\n}\n"+'\n'.join(end_lines.keys())
+ # output the function. chr(10) is '\n' (python < 3.12 doesn't support backslashes in f-strings)
+ return f'''\
+ define{(' '+self.abi) if self.abi is not None else ''} void @{name}({','.join(args)}) #0 {{
+ {chr(10).join(kernel)}
+ ret void
+ }}
+ {chr(10).join(end_lines.keys())}
+ attributes #0 = {{ nounwind "no-builtins" "no-trapping-math"="true" }}
+ '''