tinygrad 0.9.1__py3-none-any.whl → 0.9.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. tinygrad/codegen/kernel.py +248 -115
  2. tinygrad/codegen/lowerer.py +215 -0
  3. tinygrad/codegen/transcendental.py +310 -0
  4. tinygrad/codegen/uopgraph.py +622 -0
  5. tinygrad/codegen/uops.py +235 -393
  6. tinygrad/device.py +428 -69
  7. tinygrad/dtype.py +18 -4
  8. tinygrad/engine/graph.py +19 -32
  9. tinygrad/engine/jit.py +148 -70
  10. tinygrad/engine/realize.py +127 -51
  11. tinygrad/engine/schedule.py +259 -216
  12. tinygrad/engine/search.py +29 -22
  13. tinygrad/function.py +9 -0
  14. tinygrad/helpers.py +87 -49
  15. tinygrad/lazy.py +34 -35
  16. tinygrad/multi.py +41 -36
  17. tinygrad/nn/__init__.py +39 -22
  18. tinygrad/nn/state.py +3 -3
  19. tinygrad/ops.py +63 -62
  20. tinygrad/renderer/__init__.py +43 -21
  21. tinygrad/renderer/assembly.py +104 -106
  22. tinygrad/renderer/cstyle.py +87 -60
  23. tinygrad/renderer/llvmir.py +21 -30
  24. tinygrad/runtime/autogen/amd_gpu.py +25208 -5753
  25. tinygrad/runtime/autogen/cuda.py +6 -162
  26. tinygrad/runtime/autogen/kfd.py +32 -0
  27. tinygrad/runtime/autogen/libc.py +4260 -0
  28. tinygrad/runtime/autogen/nvrtc.py +579 -0
  29. tinygrad/runtime/graph/clang.py +2 -2
  30. tinygrad/runtime/graph/cuda.py +8 -11
  31. tinygrad/runtime/graph/hcq.py +120 -107
  32. tinygrad/runtime/graph/metal.py +18 -15
  33. tinygrad/runtime/ops_amd.py +197 -305
  34. tinygrad/runtime/ops_clang.py +2 -2
  35. tinygrad/runtime/ops_cuda.py +36 -94
  36. tinygrad/runtime/ops_disk.py +3 -7
  37. tinygrad/runtime/ops_gpu.py +4 -2
  38. tinygrad/runtime/ops_hip.py +70 -0
  39. tinygrad/runtime/ops_metal.py +38 -27
  40. tinygrad/runtime/ops_nv.py +283 -363
  41. tinygrad/runtime/ops_python.py +26 -30
  42. tinygrad/runtime/support/compiler_cuda.py +78 -0
  43. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +15 -1
  44. tinygrad/runtime/support/elf.py +38 -0
  45. tinygrad/shape/shapetracker.py +5 -14
  46. tinygrad/shape/symbolic.py +4 -8
  47. tinygrad/shape/view.py +34 -22
  48. tinygrad/tensor.py +399 -97
  49. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/METADATA +49 -48
  50. tinygrad-0.9.2.dist-info/RECORD +70 -0
  51. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/WHEEL +1 -1
  52. tinygrad/codegen/linearizer.py +0 -528
  53. tinygrad-0.9.1.dist-info/RECORD +0 -63
  54. /tinygrad/runtime/{driver → support}/__init__.py +0 -0
  55. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/LICENSE +0 -0
  56. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/top_level.txt +0 -0
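Note on the file moves above: tinygrad/runtime/driver/hip_comgr.py becomes tinygrad/runtime/support/compiler_hip.py (entries 43 and 54), so code that imports the HIP compiler helper by path needs updating. A defensive import spanning both layouts might look like the sketch below; compile_hip as the exported name is an assumption, not something this diff confirms. The hunks that follow are the diff for tinygrad/renderer/assembly.py, the PTX renderer (entry 21).

# Hypothetical compatibility import across 0.9.1 -> 0.9.2 (module paths from the list above;
# the compile_hip symbol is assumed, not confirmed by this diff)
try:
  from tinygrad.runtime.support.compiler_hip import compile_hip   # 0.9.2 layout
except ImportError:
  from tinygrad.runtime.driver.hip_comgr import compile_hip       # 0.9.1 layout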
@@ -1,10 +1,9 @@
 from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
 import struct, math
 from collections import defaultdict
-from tinygrad.helpers import DEBUG
 from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op
 from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
-from tinygrad.codegen.uops import UOps, UOp, UOpGraph, PatternMatcher, UPat
+from tinygrad.codegen.uops import UOps, UOp, PatternMatcher, UPat
 from tinygrad.renderer import Renderer, TensorCore
 
 def render_val(x, dtype):
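For context, render_val (whose float32 branch appears at the top of the next hunk) encodes float immediates as PTX hex literals. A minimal standalone sketch of that encoding, using an explicit little-endian pack where the original relies on native byte order:

import struct

def ptx_float_literal(x: float) -> str:
  # big-endian hex of the IEEE-754 bits, prefixed with "0f" as PTX expects for f32 immediates
  return "0f%02X%02X%02X%02X" % tuple(struct.pack("<f", x)[::-1])

assert ptx_float_literal(1.0) == "0f3F800000"
assert ptx_float_literal(-2.0) == "0fC0000000"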
@@ -14,14 +13,85 @@ def render_val(x, dtype):
     return "0f%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
   return str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")
 
+asm_for_op: Dict[Op, Callable] = {
+  UnaryOps.NEG: lambda d,a,dt,name:
+    f"not.pred {d}, {a};" if name == "pred" else f"sub.{name} {d}, 0, {a};" if dtypes.is_unsigned(dt) else f"neg.{name} {d}, {a};",
+  UnaryOps.RECIP: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};",
+  UnaryOps.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", UnaryOps.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
+  UnaryOps.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", UnaryOps.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
+  BinaryOps.SHR: lambda d,a,b,dt,name: f"shr.{name} {d}, {a}, {b};", BinaryOps.SHL: lambda d,a,b,dt,name: f"shl.b{name[1:]} {d}, {a}, {b};",
+  BinaryOps.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};",
+  BinaryOps.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
+  BinaryOps.XOR: lambda d,a,b,dt,name: f"xor.pred {d}, {a}, {b};" if name == "pred" else f"xor.b{name[1:]} {d}, {a}, {b};",
+  BinaryOps.AND: lambda d,a,b,dt, name: f"and.pred {d}, {a}, {b};" if name == "pred" else f"and.b{name[1:]} {d}, {a}, {b};",
+  BinaryOps.OR: lambda d,a,b,dt, name: f"or.pred {d}, {a}, {b};" if name == "pred" else f"or.b{name[1:]} {d}, {a}, {b};",
+  BinaryOps.IDIV: lambda d,a,b,dt,name: f"div.{name} {d}, {a}, {b};",
+  BinaryOps.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", BinaryOps.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
+  BinaryOps.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};", BinaryOps.CMPNE: lambda d,a,b,dt,name: f"setp.ne.{name} {d}, {a}, {b};",
+  TernaryOps.MULACC: lambda d,a,b,c,dt,name: f"{'fma.rn' if dtypes.is_float(dt) else 'mad.lo'}.{name} {d}, {a}, {b}, {c};",
+  TernaryOps.WHERE: lambda d,a,b,c,dt,name:
+    f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
+}
+
+supports_half: List[Op] = [UnaryOps.NEG, UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE]
+shiftable_consts = set([2**i for i in range(64)])
+ptx_matcher = PatternMatcher([
+  (UPat(UOps.ALU, BinaryOps.MUL, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
+    src=[UPat(UOps.CONST, name="const"), UPat(name="mul")]),
+   lambda root, mul, const: UOp(UOps.ALU, root.dtype,
+    (mul, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHL) if const.arg in shiftable_consts else None),
+  (UPat(UOps.ALU, BinaryOps.IDIV, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
+    src=[UPat(UOps.CONST, name="const"), UPat(name="div")]),
+   lambda root, div, const: UOp(UOps.ALU, root.dtype,
+    (div, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHR) if const.arg in shiftable_consts else None),
+  (UPat(UOps.ALU, BinaryOps.CMPNE, (UPat(dtype=dtypes.bool),UPat()), "root"), lambda root: UOp(root.op, root.dtype, root.src, BinaryOps.XOR)),
+  (UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(name="x", dtype=dtypes.bool),UPat(name="y")), "root"),
+   lambda root,x,y: UOp(root.op, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
+  (UPat(UOps.ALU, BinaryOps.ADD, [UPat(name="non_muls"), UPat(UOps.ALU, BinaryOps.MUL, name="muls")], "root"),
+   lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.src + (non_muls,), TernaryOps.MULACC)),
+  *[(UPat(UOps.ALU, op, dtype=dtypes.half, name="x"),
+    lambda x: (UOp(x.op, dtypes.float32, tuple([vv.cast(dtypes.float32) for vv in x.src]), x.arg).cast(dtypes.half)))
+    for op in asm_for_op.keys() if op not in supports_half],
+  (UPat(UOps.ALU, name="x", dtype=dtypes.bool, arg=BinaryOps.MAX),
+   lambda x: UOp(UOps.ALU, dtypes.uint8, tuple(s.cast(dtypes.uint8) for s in x.src), x.arg).cast(dtypes.bool)),
+  (UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(name="x"),UPat(name="y"),UPat(name="z"),UPat(name="k"))),
+   lambda root,x,y,z,k: UOp(root.op, dtypes.uint8, (x,y,z.cast(dtypes.uint8),k)).cast(dtypes.bool)),
+  (UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(),UPat())),
+   lambda root: UOp(root.op, dtypes.uint8, root.src, root.arg).cast(dtypes.bool)),
+  (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool), UPat())),
+   lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (z.cast(dtypes.uint8),), root.arg)),
+  (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool))),
+   lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (z.cast(dtypes.uint8),), root.arg)),
+  (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(),UPat(name="g", dtype=dtypes.int))),
+   lambda root,g: UOp(root.op, root.dtype, root.src[:3] + (g.cast(dtypes.uint8),), root.arg)),
+  # ptr_ar (load/store)
+  (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_any_len=True, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
+    UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="alu"), UPat(UOps.CONST, name="const")]))),
+   lambda root, alu, const: UOp(root.op, root.dtype,
+    (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
+     const.const(root.src[0].dtype.itemsize)*const)+root.src[2:])),
+  (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_any_len=True, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
+    UPat(UOps.CONST, name="const"))),
+   lambda root, const: UOp(root.op, root.dtype,
+    (root.src[0].cast(dtypes.int64),
+     UOp.const(dtypes.int64, const.arg*root.src[0].dtype.itemsize),)+root.src[2:])),
+  (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_any_len=True, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
+    UPat(name="alu"))),  # no const here
+   lambda root, alu: UOp(root.op, root.dtype,
+    (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
+     UOp.const(dtypes.int64, 0))+root.src[2:])),
+])
+
 class PTXRenderer(Renderer):
   device = "CUDA"
   suffix = "PTX"
   global_max = (2147483647, 65535, 65535)
   local_max = (1024, 1024, 64)
   shared_max = 49152
-  tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[0],[5],[-2],[0],[-1,1,2,-3],[3,4]], [[3],[4],[0],[0],[5],[-1,1,2,-2],[0]], [[-1],[1],[5],[-2],[2],[0],[3,4]] ], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float)])] # noqa: E501
-  def __init__(self, arch:str): self.tensor_cores = PTXRenderer.tensor_cores if int(arch[3:]) >= 80 else []
+  tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(1,2)], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float)])] # noqa: E501
+  code_for_op = asm_for_op
+  extra_matcher = ptx_matcher
+  def __init__(self, arch:str, device="CUDA"): self.device, self.tensor_cores = device, PTXRenderer.tensor_cores if int(arch[3:]) >= 80 else []
 
   # language options
   kernel_prefix = """.version VERSION
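The first two ptx_matcher rules above fold integer multiplies and divides by power-of-two constants into shifts, and (new in 0.9.2) return None so the rewrite is skipped when the constant is not in shiftable_consts. A standalone sketch of the arithmetic those rules rely on (plain Python, not tinygrad code):

import math

shiftable_consts = set([2**i for i in range(64)])

def strength_reduce(x: int, c: int) -> tuple:
  # mirrors the MUL->SHL / IDIV->SHR rewrites: only exact powers of two become shifts
  if c in shiftable_consts:
    s = int(math.log2(c))          # always an integer here
    return x << s, x >> s          # equals x*c and x//c for non-negative x
  return x * c, x // c             # no rewrite: keep the plain mul/div

assert strength_reduce(10, 8) == (80, 1)   # 10*8 == 10<<3, 10//8 == 10>>3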
@@ -29,29 +99,7 @@ class PTXRenderer(Renderer):
 .address_size 64
 .visible .entry"""
   barrier = "bar.sync\t0;"
-  gid = [f'%ctaid.{chr(120+i)}' for i in range(3)]
-  gdim = [f'%nctaid.{chr(120+i)}' for i in range(3)]
-  lid = [f'%tid.{chr(120+i)}' for i in range(3)]
-  asm_for_op: Dict[Op, Callable] = {
-    UnaryOps.NEG: lambda d,a,dt,name: f"not.pred {d}, {a};" if name == "pred" else f"sub.{name} {d}, 0, {a};" if dtypes.is_unsigned(dt) \
-      else f"neg.{name} {d}, {a};",
-    UnaryOps.RECIP: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};",
-    UnaryOps.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", UnaryOps.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
-    UnaryOps.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", UnaryOps.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
-    BinaryOps.SHR: lambda d,a,b,dt,name: f"shr.{name} {d}, {a}, {b};", BinaryOps.SHL: lambda d,a,b,dt,name: f"shl.b{name[1:]} {d}, {a}, {b};",
-    BinaryOps.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};",
-    BinaryOps.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
-    BinaryOps.XOR: lambda d,a,b,dt,name: f"xor.pred {d}, {a}, {b};" if name == "pred" else f"xor.b{name[1:]} {d}, {a}, {b};",
-    BinaryOps.IDIV: lambda d,a,b,dt,name: f"div.{name} {d}, {a}, {b};",
-    BinaryOps.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", BinaryOps.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
-    BinaryOps.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};",
-    BinaryOps.CMPNE: lambda d,a,b,dt,name: f"setp.ne.{name} {d}, {a}, {b};",
-    TernaryOps.MULACC: lambda d,a,b,c,dt,name: f"{'fma.rn' if dtypes.is_float(dt) else 'mad.lo'}.{name} {d}, {a}, {b}, {c};",
-    TernaryOps.WHERE: lambda d,a,b,c,dt,name:
-      f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
-  }
-  supports_half: List[Op] = [UnaryOps.NEG, UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT,
-                             TernaryOps.WHERE]
+  supports_half = supports_half
   # HACK: Use s16 and u16 for int8 and uint8 buffers. This can be wrong in cast.
   types: Dict[DType, str] = { dtypes.int8: "s16", dtypes.int16: "s16", dtypes.int32: "s32", dtypes.int64: "s64",
                               dtypes.uint8: "u16", dtypes.uint16: "u16", dtypes.uint32: "u32", dtypes.uint64: "u64",
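The removed gid/gdim/lid tables built PTX special-register names with chr(120+i); the 0.9.2 SPECIAL branch (later in this diff) derives the same name from the trailing digit of the arg instead. A small sketch of that mapping; the "gidx0"/"lidx0" name shape is an assumption based on the new branch:

def special_reg(name: str) -> str:
  # chr(120) == 'x', so dimensions 0/1/2 map to the .x/.y/.z component
  return f"%{'ctaid' if name[0] == 'g' else 'tid'}.{chr(120 + int(name[-1]))}"

assert [chr(120 + i) for i in range(3)] == ['x', 'y', 'z']
assert special_reg("gidx0") == "%ctaid.x" and special_reg("lidx2") == "%tid.z"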
@@ -98,13 +146,10 @@ class PTXRenderer(Renderer):
             '\n'.join([fmt(line) for op in kernel for line in op.splitlines()]) +
             "\n}")
 
-  def render(self, name:str, uops:UOpGraph) -> str:
+  def render(self, name:str, uops:List[UOp]) -> str:
     kernel:List[str] = []
     bufs = []
 
-    uops.linearize(ptx_matcher)
-    if DEBUG >= 4: uops.print()
-
     def kk(*s: str): kernel.append("\n".join(s))
 
     c: DefaultDict[str, int] = defaultdict(int)
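The signature change above means the renderer no longer linearizes or prints the graph itself: in 0.9.1 render() took a UOpGraph and called uops.linearize(ptx_matcher) internally, while in 0.9.2 it receives an already-linearized List[UOp] and the PTX-specific rewrites run upstream through the class-level extra_matcher. A toy sketch of that contract shift, with stand-in types (not tinygrad APIs):

from typing import List

class FakeGraph:                       # stand-in for the old UOpGraph, not the real class
  def __init__(self, uops: List[str]): self._uops = uops
  def linearize(self, matcher=None) -> List[str]: return self._uops

def render_091(name: str, graph: FakeGraph) -> str:
  return f"// {name}: {len(graph.linearize())} uops"     # renderer linearizes internally

def render_092(name: str, uops: List[str]) -> str:
  return f"// {name}: {len(uops)} uops"                  # caller passes the flat list

g = FakeGraph(["DEFINE_GLOBAL", "LOAD", "STORE"])
assert render_091("k", g) == render_092("k", g.linearize())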
@@ -133,14 +178,14 @@ class PTXRenderer(Renderer):
       uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
       if uop is UOps.IF:
         assert src[0].dtype is not None
-        kk(*self.render_bra(f"IF_{r[src[0]][1:]}_{cast(List, uops._uops).index(u)}", _cast(r[src[0]], dtypes.bool, src[0].dtype, u=u, pred=True)))
+        kk(*self.render_bra(f"IF_{r[src[0]][1:]}_{uops.index(u)}", _cast(r[src[0]], dtypes.bool, src[0].dtype, u=u, pred=True)))
       elif uop is UOps.BARRIER and self.barrier: kk(self.barrier)
       elif uop is UOps.ENDRANGE:
-        kk(self.asm_for_op[BinaryOps.ADD](r[src[0]], r[src[0]], "1", dtypes.int, self.types[dtypes.int]),
-           self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[src[0]], r[src[0].src[1]], dtypes.int, self.types[dtypes.int]))
+        kk(self.code_for_op[BinaryOps.ADD](r[src[0]], r[src[0]], "1", dtypes.int, self.types[dtypes.int]),
+           self.code_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[src[0]], r[src[0].src[1]], dtypes.int, self.types[dtypes.int]))
         kk(*self.render_bra(f"LOOP_{r[src[0]][1:]}", pred))
       elif uop is UOps.ENDIF:
-        kk(f"IF_{r[src[0].src[0]][1:]}_{cast(List, uops._uops).index(src[0])}:")
+        kk(f"IF_{r[src[0].src[0]][1:]}_{uops.index(src[0])}:")
       elif uop is UOps.STORE:
         assert src[0].dtype is not None and src[2].dtype is not None
         assert src[0].dtype == dtypes.int64, "store isn't int64"
@@ -156,58 +201,54 @@ class PTXRenderer(Renderer):
         if uop is UOps.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[src[0]], "LOOP_"+loop[1:]))
         elif uop is UOps.ALU:
           assert src[0].dtype is not None
-          if args is BinaryOps.CMPLT or args is BinaryOps.CMPNE:
-            # pass in the other dtype here
-            kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in src], src[0].dtype, self.types[src[0].dtype]))
-          else:
-            kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in src], dtype, self.types[dtype]))
+          src_dtype = src[0].dtype if args in {BinaryOps.CMPLT, BinaryOps.CMPNE} else dtype
+          kk(self.code_for_op[args](ssa("alu", u), *[r[x] for x in src], src_dtype, self.types[src_dtype]))
         elif uop is UOps.DEFINE_ACC:
           if dtype.count > 1:
             r[u] = [ssa('acc', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
-            for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(src[0].arg, dtype.scalar())};")
-          else: kk(f"mov.b{self.types[dtype][1:]} {ssa('acc', u)}, {const(src[0].arg, dtype)};")
+            for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(src[0].src[0].arg, dtype.scalar())};")
+          else: kk(f"mov.{f'b{self.types[dtype][1:]}' if dtype != dtypes.bool else 'pred'} {ssa('acc', u)}, {const(src[0].arg, dtype)};")
         elif uop is UOps.SPECIAL:
-          assert args[1][0] != "i", "idx not supported"
-          kk(f"mov.u32 %{args[1]}, {(self.gid if args[1][0] == 'g' else self.lid)[args[0]]};")
-          r[u] = "%" + args[1]
-          kernel = [f".reg .u32 %{args[1]};"] + kernel
-        elif uop is UOps.CONST:
-          if dtype.count > 1: r[u] = [const(args, dtype.scalar(), mov=True) for _ in range(dtype.count)]
-          else: r[u] = const(args, dtype, mov=True)
+          assert args[0][0] != "i", "idx not supported"
+          kk(f"mov.u32 %{args[0]}, %{'ctaid' if args[0][0] == 'g' else 'tid'}.{chr(120+int(args[0][-1]))};")
+          r[u] = "%" + args[0]
+          kernel = [f".reg .u32 %{args[0]};"] + kernel
+        elif uop is UOps.DEFINE_VAR:
+          bufs.append((args.expr, dtype))
+          r[u] = f"%{args.expr}"
+          kk(*self.render_load(args.expr, ssa('dat', u, self.types[dtype]), dtype, ss=".param"))
+        elif uop is UOps.CONST: r[u] = const(args, dtype, mov=True)
         elif uop is UOps.GEP: r[u] = r[src[0]][u.arg]
         elif uop is UOps.LOAD:
          assert src[0].dtype == dtypes.int64, "load isn't int64"
          assert src[1].op is UOps.CONST, f"load isn't const {u}"
          mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global'
+         has_gate = len(src) > 3 and src[3].op is UOps.ALU
          if dtype.count > 1:
            r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
-           if(len(src)>3):
+           if has_gate:
             for v in r[u]: kk(f"mov.{self.mem_types[dtype.scalar()]} {v}, {render_val(0, dtype.scalar())};")
-           kk((f"@{r[src[2]]}"if len(src) > 3 else "")
+           kk((f"@{r[src[3]]}"if has_gate else "")
              + f" ld{mem_type}.v{dtype.count}.{self.mem_types[dtype.scalar()]} {{{', '.join(r[u])}}}, [{r[src[0]]}+{src[1].arg}];")
          else:
-           kk(*self.render_load(r[src[0]], ssa('val', u), dtype, gate=r[src[2]] if len(src) > 3 else None,
-              alt=r[src[3]] if len(src) > 3 else None, ss=mem_type, offset=src[1].arg))
+           kk(*self.render_load(r[src[0]], ssa('val', u), dtype, gate=r[src[3]] if has_gate else None,
+              alt=r[src[2]] if has_gate else None, ss=mem_type, offset=src[1].arg))
         elif uop is UOps.PHI:
          if dtype.count > 1:
            for x0, x1 in zip(r[src[0]], r[src[1]]): kk(f"mov.b{self.types[dtype.scalar()][1:]} {x0}, {x1};")
-         else:
-           kk(f"mov.b{self.types[dtype][1:]} {r[src[0]]}, {r[src[1]]};")
+         else: kk(f"mov.{f'b{self.types[dtype][1:]}' if dtype != dtypes.bool else 'pred'} {r[src[0]]}, {r[src[1]]};")
          r[u] = r[src[0]]
+        # NOTE: casting to str is fine because you can't vectorize a vectorize
+        elif uop is UOps.VECTORIZE: r[u] = [cast(str,r[x]) for x in src]
        elif uop in {UOps.CAST, UOps.BITCAST}:
-         assert src[0].dtype is not None
-         if dtype.count>1: r[u] = [r[x] for x in src] # type: ignore
-         else: _cast(r[src[0]], dtype, src[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
+         assert src[0].dtype is not None and dtype.count == 1
+         _cast(r[src[0]], dtype, src[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
        elif uop is UOps.DEFINE_LOCAL:
          # TODO: we should sum these, and fetch 0xC000 from somewhere
          assert args[1]*dtype.itemsize <= 0xC000, "too large local"
          kk(*self.render_local(ssa('local', u, self.types[dtypes.ulong]), args[0], args[1], dtype))
-       elif uop is UOps.DEFINE_VAR:
-         bufs.append((args.expr, dtype))
-         r[u] = f"%{args.expr}"
-         kk(*self.render_load(args.expr, ssa('dat', u, self.types[dtype]), dtype, ss=".param"))
        elif uop is UOps.DEFINE_GLOBAL:
-         bufs.append((nm:=f"data{args[0]}", dtype))
+         bufs.append((nm:=f"data{args}", dtype))
          r[u] = f"%{nm}"
          dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
          kk(*self.render_load(nm, ssa('dat', u, self.types[dt]), dt, ss=".param"))
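One behavioral detail worth calling out from the LOAD changes above: the gate and the alternative value swap positions in the source tuple. In 0.9.1 the gate was src[2] and the fallback value src[3]; in 0.9.2 the fallback is src[2] and the gate src[3] (and has_gate additionally requires src[3] to be an ALU op). A small sketch of the two layouts as read from this hunk, not an official spec:

def pick_gate_and_alt(src: tuple, new_layout: bool = True):
  # src ~ (address, offset_const, ...) ; gated loads carry four sources
  if len(src) <= 3: return None, None                    # ungated load
  return (src[3], src[2]) if new_layout else (src[2], src[3])

assert pick_gate_and_alt(("addr", "off", "alt", "gate")) == ("gate", "alt")                     # 0.9.2
assert pick_gate_and_alt(("addr", "off", "gate", "alt"), new_layout=False) == ("gate", "alt")   # 0.9.1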
@@ -224,46 +265,3 @@ class PTXRenderer(Renderer):
 
     return self.render_kernel(kernel, name, bufs, c.items())
 
-ptx_matcher = PatternMatcher([
-  (UPat(UOps.ALU, BinaryOps.MUL, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
-    src=[UPat(UOps.CONST, set([2**i for i in range(64)]), name="const"), UPat(name="mul")]),
-   lambda root, mul, const: UOp(UOps.ALU, root.dtype, (mul, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHL)),
-  (UPat(UOps.ALU, BinaryOps.IDIV, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
-    src=[UPat(UOps.CONST, set([2**i for i in range(64)]), name="const"), UPat(name="div")]),
-   lambda root, div, const: UOp(UOps.ALU, root.dtype, (div, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHR)),
-  (UPat(UOps.ALU, BinaryOps.CMPNE, (UPat(dtype=dtypes.bool),UPat()), "root"), lambda root: UOp(root.op, root.dtype, root.src, BinaryOps.XOR)),
-  (UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(name="x", dtype=dtypes.bool),UPat(name="y")), "root"),
-   lambda root,x,y: UOp(root.op, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
-  (UPat(UOps.ALU, BinaryOps.ADD,
-    [UPat(name="non_muls"), UPat(UOps.ALU, BinaryOps.MUL, name="muls")], "root"),
-   lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.src + (non_muls,), TernaryOps.MULACC)),
-  *[(UPat(UOps.ALU, op, dtype=dtypes.half, name="x"),
-    lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.op, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.src]), x.arg),)))
-    for op in PTXRenderer.asm_for_op.keys() if op not in PTXRenderer.supports_half],
-  (UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(name="x"),UPat(name="y"),UPat(name="z"),UPat(name="k"))),
-   lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.op, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)),
-  (UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(),UPat())),
-   lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.op, dtypes.uint8, root.src, root.arg),))),
-  (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool), UPat())),
-   lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
-  (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool))),
-   lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
-  (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(),UPat(name="g", dtype=dtypes.int))),
-   lambda root,g: UOp(root.op, root.dtype, root.src[:3] + (UOp(UOps.CAST, dtypes.bool, (g,)),), root.arg)),
-  # ptr_ar (load/store)
-  (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
-    UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="alu"), UPat(UOps.CONST, name="const")]))),
-   lambda root, alu, const: UOp(root.op, root.dtype,
-    (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
-     UOp.const(const.dtype, root.src[0].dtype.itemsize)*const)+root.src[2:])),
-  (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
-    UPat(UOps.CONST, name="const"))),
-   lambda root, const: UOp(root.op, root.dtype, (root.src[0].cast(dtypes.int64),
-    UOp.const(dtypes.int64, const.arg * root.src[0].dtype.itemsize),
-    )+root.src[2:])),
-  (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
-    UPat(name="alu"))),  # no const here
-   lambda root, alu: UOp(root.op, root.dtype,
-    (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
-     UOp.const(dtypes.int64, 0))+root.src[2:])),
-])