tinygrad 0.7.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. tinygrad/__init__.py +6 -0
  2. tinygrad/codegen/kernel.py +572 -83
  3. tinygrad/codegen/linearizer.py +415 -395
  4. tinygrad/codegen/uops.py +415 -0
  5. tinygrad/device.py +183 -0
  6. tinygrad/dtype.py +113 -0
  7. tinygrad/engine/__init__.py +0 -0
  8. tinygrad/engine/graph.py +100 -0
  9. tinygrad/engine/jit.py +195 -0
  10. tinygrad/engine/realize.py +191 -0
  11. tinygrad/engine/schedule.py +362 -0
  12. tinygrad/engine/search.py +196 -0
  13. tinygrad/{mlops.py → function.py} +76 -55
  14. tinygrad/helpers.py +196 -89
  15. tinygrad/lazy.py +210 -371
  16. tinygrad/multi.py +169 -0
  17. tinygrad/nn/__init__.py +202 -22
  18. tinygrad/nn/datasets.py +7 -0
  19. tinygrad/nn/optim.py +112 -32
  20. tinygrad/nn/state.py +136 -39
  21. tinygrad/ops.py +119 -202
  22. tinygrad/renderer/__init__.py +61 -0
  23. tinygrad/renderer/assembly.py +276 -0
  24. tinygrad/renderer/cstyle.py +353 -166
  25. tinygrad/renderer/llvmir.py +150 -138
  26. tinygrad/runtime/autogen/amd_gpu.py +1900 -0
  27. tinygrad/runtime/autogen/comgr.py +865 -0
  28. tinygrad/runtime/autogen/cuda.py +5923 -0
  29. tinygrad/runtime/autogen/hip.py +5909 -0
  30. tinygrad/runtime/autogen/hsa.py +5761 -0
  31. tinygrad/runtime/autogen/kfd.py +812 -0
  32. tinygrad/runtime/autogen/nv_gpu.py +33328 -0
  33. tinygrad/runtime/autogen/opencl.py +1795 -0
  34. tinygrad/runtime/driver/hip_comgr.py +47 -0
  35. tinygrad/runtime/driver/hsa.py +143 -0
  36. tinygrad/runtime/graph/clang.py +38 -0
  37. tinygrad/runtime/graph/cuda.py +81 -0
  38. tinygrad/runtime/graph/hcq.py +143 -0
  39. tinygrad/runtime/graph/hsa.py +171 -0
  40. tinygrad/runtime/graph/metal.py +75 -0
  41. tinygrad/runtime/ops_amd.py +564 -0
  42. tinygrad/runtime/ops_clang.py +24 -77
  43. tinygrad/runtime/ops_cuda.py +175 -89
  44. tinygrad/runtime/ops_disk.py +56 -33
  45. tinygrad/runtime/ops_gpu.py +92 -95
  46. tinygrad/runtime/ops_hsa.py +278 -0
  47. tinygrad/runtime/ops_llvm.py +39 -60
  48. tinygrad/runtime/ops_metal.py +92 -74
  49. tinygrad/runtime/ops_npy.py +9 -0
  50. tinygrad/runtime/ops_nv.py +630 -0
  51. tinygrad/runtime/ops_python.py +204 -0
  52. tinygrad/shape/shapetracker.py +86 -254
  53. tinygrad/shape/symbolic.py +166 -141
  54. tinygrad/shape/view.py +296 -0
  55. tinygrad/tensor.py +2619 -448
  56. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
  57. tinygrad-0.9.0.dist-info/METADATA +227 -0
  58. tinygrad-0.9.0.dist-info/RECORD +60 -0
  59. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
  60. tinygrad/codegen/assembly.py +0 -190
  61. tinygrad/codegen/optimizer.py +0 -379
  62. tinygrad/codegen/search.py +0 -72
  63. tinygrad/graph.py +0 -83
  64. tinygrad/jit.py +0 -57
  65. tinygrad/nn/image.py +0 -100
  66. tinygrad/renderer/assembly_arm64.py +0 -169
  67. tinygrad/renderer/assembly_ptx.py +0 -98
  68. tinygrad/renderer/wgsl.py +0 -53
  69. tinygrad/runtime/lib.py +0 -113
  70. tinygrad/runtime/ops_cpu.py +0 -51
  71. tinygrad/runtime/ops_hip.py +0 -82
  72. tinygrad/runtime/ops_shm.py +0 -29
  73. tinygrad/runtime/ops_torch.py +0 -30
  74. tinygrad/runtime/ops_webgpu.py +0 -45
  75. tinygrad-0.7.0.dist-info/METADATA +0 -212
  76. tinygrad-0.7.0.dist-info/RECORD +0 -40
  77. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
@@ -1,148 +1,160 @@
1
- from typing import Final, Dict, Callable, Any, List, Optional, Tuple
2
- import functools
3
- from llvmlite import ir # type: ignore
4
- from tinygrad.codegen.linearizer import UOps, UOp, Token, MemOp, ConstOp
5
- from tinygrad.helpers import dtypes
1
+ from typing import Final, Dict, Callable, Any, List, Optional
2
+ from llvmlite import ir
3
+ from tinygrad.codegen.linearizer import UOps, UOp
4
+ from tinygrad.dtype import DType, PtrDType, dtypes
6
5
  from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
6
+ from tinygrad.codegen.uops import UOpGraph
7
+ from tinygrad.renderer import Renderer
7
8
 
8
- from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
9
- def int_const(x): return ir.Constant(ir.IntType(64), x)
10
- render_llvm = {
11
- NumNode: lambda self,ops,ctx: int_const(self.b),
12
- MulNode: lambda self,ops,ctx: ctx.mul(self.a.render(ops,ctx), int_const(self.b)),
13
- DivNode: lambda self,ops,ctx: ctx.sdiv(self.a.render(ops,ctx), int_const(self.b)),
14
- ModNode: lambda self,ops,ctx: ctx.srem(self.a.render(ops,ctx), int_const(self.b)),
15
- LtNode: lambda self,ops,ctx: ctx.icmp_signed("<", self.a.render(ops,ctx), int_const(self.b)),
16
- SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.add(a,b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
17
- AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.and_(a,b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx))
18
- }
9
+ MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf
10
+
11
+ def is_bool_or_unsigned(dtype: DType): return dtype == dtypes.bool or dtypes.is_unsigned(dtype)
19
12
 
20
13
  code_for_op: Final[Dict[Op, Callable]] = {
21
- UnaryOps.EXP2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp2', [ir.FloatType()]), [x], fastmath=('fast',)),
22
- UnaryOps.LOG2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log2', [ir.FloatType()]), [x], fastmath=('fast',)),
23
- UnaryOps.SIN: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sin', [ir.FloatType()]), [x], fastmath=('fast',)),
24
- UnaryOps.SQRT: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sqrt', [ir.FloatType()]), [x], fastmath=('fast',)),
25
- BinaryOps.ADD: lambda builder,x,y: builder.add(x,y) if isinstance(x.type, ir.IntType) else builder.fadd(x,y, flags=('fast',)),
26
- BinaryOps.SUB: lambda builder,x,y: builder.sub(x,y) if isinstance(x.type, ir.IntType) else builder.fsub(x,y, flags=('fast',)),
27
- BinaryOps.MUL: lambda builder,x,y: builder.mul(x,y) if isinstance(x.type, ir.IntType) else builder.fmul(x,y, flags=('fast',)),
28
- BinaryOps.DIV: lambda builder,x,y: builder.sdiv(x,y) if isinstance(x.type, ir.IntType) else builder.fdiv(x,y, flags=('fast',)),
29
- BinaryOps.CMPLT: lambda builder,x,y: builder.zext(builder.icmp_signed("<", x, y),ir.IntType(32)) if isinstance(x.type, ir.IntType) else builder.uitofp(builder.fcmp_ordered("<", x, y, flags=('fast',)), ir.FloatType()),
30
- BinaryOps.MAX: lambda builder,x,y: builder.select(builder.fcmp_unordered(">", x, y, flags=('fast',)), x, y, flags=('fast',)),
31
- BinaryOps.MOD: lambda builder,x,y: builder.srem(x,y),
32
- TernaryOps.MULACC: lambda builder,x,y,z: builder.fadd(builder.fmul(x,y, flags=('fast',)), z, flags=('fast',)),
33
- TernaryOps.WHERE: lambda builder,x,y,z: builder.select(builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=('fast',)), y, z, flags=('fast',)),
34
- }
35
-
36
- dtype_to_llvm_dtype = {dtypes.float64:ir.DoubleType(), dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32)}
37
-
38
- def cast(bb, val, input_type, output_type):
14
+ UnaryOps.NEG: lambda builder, x, dtype: builder.neg(x) if dtypes.is_int(dtype) else \
15
+ (builder.not_(x) if dtype == dtypes.bool else builder.fneg(x, flags=MFLAGS)),
16
+ UnaryOps.EXP2: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.exp2', [x.type]), [x], fastmath=MFLAGS),
17
+ UnaryOps.LOG2: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.log2', [x.type]), [x], fastmath=MFLAGS),
18
+ UnaryOps.SIN: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.sin', [x.type]), [x], fastmath=MFLAGS),
19
+ UnaryOps.SQRT: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.sqrt', [x.type]), [x], fastmath=MFLAGS),
20
+ BinaryOps.ADD: lambda builder, x, y, dtype: builder.or_(x, y) if dtype == dtypes.bool else builder.add(x, y) if dtypes.is_int(dtype) else builder.fadd(x, y, flags=MFLAGS), # noqa: E501
21
+ BinaryOps.SUB: lambda builder, x, y, dtype: builder.sub(x, y) if dtypes.is_int(dtype) else builder.fsub(x, y, flags=MFLAGS),
22
+ BinaryOps.MUL: lambda builder, x, y, dtype: builder.mul(x, y) if is_bool_or_unsigned(dtype) or dtypes.is_int(dtype) else builder.fmul(x, y, flags=MFLAGS), # noqa: E501
23
+ BinaryOps.DIV: lambda builder, x, y, dtype: builder.udiv(x, y) if is_bool_or_unsigned(dtype) else builder.sdiv(x, y) if dtypes.is_int(dtype) else builder.fdiv(x, y, flags=MFLAGS), # noqa: E501
24
+ BinaryOps.CMPLT: lambda builder, x, y, dtype: builder.icmp_unsigned("<", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("<", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("<", x, y, flags=MFLAGS), # noqa: E501
25
+ BinaryOps.CMPEQ: lambda builder, x, y, dtype: builder.icmp_unsigned("==", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("==", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("==", x, y, flags=MFLAGS), # noqa: E501
26
+ BinaryOps.MAX: lambda builder, x, y, dtype: builder.select(builder.icmp_unsigned(">", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed(">", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered(">", x, y, flags=MFLAGS), x, y), # noqa: E501
27
+ BinaryOps.MOD: lambda builder, x, y, dtype: builder.urem(x, y) if is_bool_or_unsigned(dtype) else builder.srem(x, y) if dtypes.is_int(dtype) else builder.frem(x, y), # noqa: E501
28
+ BinaryOps.XOR: lambda builder, x, y, dtype: builder.xor(x, y),
29
+ TernaryOps.WHERE: lambda builder, x, y, z, dtype: builder.select(x, y, z)}
30
+
31
+ dtype_to_llvm_dtype = { dtypes.bool:ir.IntType(1), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.int16:ir.IntType(16),
32
+ dtypes.uint16:ir.IntType(16), dtypes.int32:ir.IntType(32), dtypes.uint32:ir.IntType(32), dtypes.int64:ir.IntType(64), dtypes.uint64:ir.IntType(64),
33
+ dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.float64:ir.DoubleType() }
34
+
35
+ def cast(bb, val, input_type, output_type, bitcast=False):
39
36
  if input_type == output_type: return val
40
-
41
- if output_type == dtypes.float32:
42
- if dtypes.is_int(input_type) or input_type == dtypes.bool:
43
- val = bb[-1].uitofp(val, ir.FloatType()) if dtypes.is_unsigned(input_type) or input_type == dtypes.bool else bb[-1].sitofp(val, ir.FloatType())
44
- elif input_type == dtypes.bfloat16:
45
- val = bb[-1].sext(val, ir.IntType(32))
46
- val = bb[-1].shl(val, ir.Constant(ir.IntType(32), 16))
47
- val = bb[-1].bitcast(val, ir.FloatType())
48
- elif input_type == dtypes.float64:
49
- val = bb[-1].fptrunc(val, ir.FloatType())
50
- else:
51
- val = bb[-1].fpext(val, ir.FloatType())
52
- return val
53
-
54
- if input_type == dtypes.float32:
55
- if dtypes.is_int(output_type) or output_type == dtypes.bool:
56
- val = bb[-1].fptoui(val, dtype_to_llvm_dtype[output_type]) if dtypes.is_unsigned(output_type) or output_type == dtypes.bool else bb[-1].fptosi(val, dtype_to_llvm_dtype[output_type])
57
- elif output_type == dtypes.bfloat16:
58
- val = bb[-1].bitcast(val, ir.IntType(32))
59
- val = bb[-1].lshr(val, ir.Constant(ir.IntType(32), 16))
60
- val = bb[-1].trunc(val, ir.IntType(16))
61
- elif output_type == dtypes.float64:
62
- val = bb[-1].fpext(val, ir.DoubleType())
63
- else:
64
- val = bb[-1].fptrunc(val, dtype_to_llvm_dtype[output_type])
65
- return val
37
+ llvm_type = dtype_to_llvm_dtype[output_type]
38
+ if bitcast: return bb[-1].bitcast(val, llvm_type)
39
+
40
+ if input_type == dtypes.bfloat16:
41
+ val = bb[-1].bitcast(bb[-1].shl(bb[-1].sext(val, ir.IntType(32)), ir.Constant(ir.IntType(32), 16)),val, ir.FloatType())
42
+ input_type = dtypes.float32
43
+ if output_type == dtypes.bfloat16:
44
+ val = cast(bb, val, input_type, dtypes.float32)
45
+ return bb[-1].trunc(bb[-1].lshr(bb[-1].bitcast(val, ir.IntType(32)), ir.Constant(ir.IntType(32), 16)), ir.IntType(16))
46
+
47
+ if dtypes.is_float(input_type):
48
+ if dtypes.is_float(output_type):
49
+ return bb[-1].fpext(val, llvm_type) if output_type.itemsize > input_type.itemsize else bb[-1].fptrunc(val, llvm_type)
50
+ if dtypes.is_int(output_type): return bb[-1].fptoui(val, llvm_type) if dtypes.is_unsigned(output_type) else bb[-1].fptosi(val, llvm_type)
51
+ if output_type == dtypes.bool: return bb[-1].fcmp_unordered('!=', cast(bb, val, input_type, dtypes.float32), ir.Constant(ir.FloatType(), 0))
52
+
53
+ if dtypes.is_unsigned(input_type) or input_type == dtypes.bool:
54
+ if output_type == dtypes.float16: return bb[-1].fptrunc(bb[-1].uitofp(val, ir.FloatType()), ir.HalfType())
55
+ if dtypes.is_float(output_type): return bb[-1].uitofp(val, dtype_to_llvm_dtype[output_type])
56
+ if dtypes.is_int(output_type): return bb[-1].trunc(val, llvm_type) if input_type.itemsize > output_type.itemsize else bb[-1].zext(val, llvm_type)
57
+ if output_type == dtypes.bool: return bb[-1].icmp_unsigned('!=', val, ir.Constant(val.type, 0))
58
+
59
+ if dtypes.is_int(input_type):
60
+ if output_type == dtypes.float16: return bb[-1].fptrunc(bb[-1].sitofp(val, ir.FloatType()), ir.HalfType())
61
+ if dtypes.is_float(output_type): return bb[-1].sitofp(val, llvm_type)
62
+ if dtypes.is_int(output_type): return bb[-1].trunc(val, llvm_type) if input_type.itemsize > output_type.itemsize else bb[-1].sext(val, llvm_type)
63
+ if output_type == dtypes.bool: return bb[-1].icmp_signed('!=', val, ir.Constant(val.type, 0))
66
64
 
67
65
  raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")
68
66
 
69
- def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[List[int]], Optional[List[int]]]:
70
- # all llvm stuff goes into a module
71
- module = ir.Module(name=__file__)
72
-
73
- # extract global buffers
74
- buf_to_dtype = {args[0]:args[1] for uop,_,_,args in uops if uop == UOps.DEFINE_GLOBAL}
75
- buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
76
-
77
- # create llvm function
78
- func_dtypes = [dtype_to_llvm_dtype[dtype] for dtype in buf_to_dtype.values()]
79
- func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() for x in func_dtypes]), name=function_name)
80
- for a in func.args: a.add_attribute("noalias")
81
-
82
- # force llvmlite to allow us to add function attribute then add the attribute
83
- func.attributes._known = func.attributes._known.union(frozenset(['"no-nans-fp-math"="true"']))
84
- func.attributes.add('"no-nans-fp-math"="true"')
85
-
86
- bb = [ir.IRBuilder(func.append_basic_block("entry"))]
87
- loop_blocks = []
88
- reduce_phis: List = []
89
- # TODO: newvar probably shouldn't be optional
90
- lvars: Dict[Optional[Token], Any] = {} # this Any is an llvm type
91
- render_llvm[Variable] = lambda self,ops,ctx: lvars[self.expr]
92
-
93
- for uop,newvar,vin,args in uops:
94
- if uop == UOps.LOOP:
95
- for var in args[0]:
96
- if isinstance(var, NumNode): continue
97
- bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{var.expr}")))
98
- bb[-2].branch(bb[-1]._block)
99
-
100
- phis = []
101
- for rp in reduce_phis:
102
- incoming = lvars[rp]
103
- lvars[rp] = bb[-1].phi(ir.FloatType())
104
- lvars[rp].add_incoming(incoming, bb[-2]._block)
105
- phis.append((rp, lvars[rp]))
106
- loop_blocks.append((bb[-1], phis))
107
-
108
- lvars[var.expr] = bb[-1].phi(ir.IntType(64), name=var.expr)
109
- lvars[var.expr].add_incoming(int_const(var.min), bb[-2]._block)
110
- if uop == UOps.ENDLOOP:
111
- for var in args[0][::-1]:
112
- if isinstance(var, NumNode): continue
113
- block, phis = loop_blocks.pop()
114
- idx_p1 = bb[-1].add(lvars[var.expr], int_const(1))
115
- lvars[var.expr].add_incoming(idx_p1, bb[-1]._block)
116
- for n,phi in phis: phi.add_incoming(lvars[n], bb[-1]._block)
117
- bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{var.expr}")))
118
- bb[-2].cbranch(bb[-2].icmp_unsigned("==", idx_p1, int_const(var.max+1)), bb[-1]._block, block._block)
119
- if uop == UOps.LOAD:
120
- assert newvar is not None and isinstance(args, (MemOp, ConstOp))
121
- valid = args.valid.render(render_llvm, bb[-1])
122
- if isinstance(args, ConstOp):
123
- value, invalid_value = [int(args.value), int(args.invalid_value)] if dtypes.is_int(newvar.dtype) else ([bool(args.value), bool(args.invalid_value)] if newvar.dtype == dtypes.bool else [args.value, args.invalid_value]) # type: ignore
124
- if args.valid.min == 0 and args.valid.max == 1:
125
- val = bb[-1].select(valid, ir.Constant(dtype_to_llvm_dtype[newvar.dtype], value), ir.Constant(dtype_to_llvm_dtype[newvar.dtype], invalid_value))
67
+ def const(args, dtype): return ir.Constant(dtype_to_llvm_dtype[dtype], args)
68
+
69
+ class LLVMRenderer(Renderer):
70
+ device = "LLVM"
71
+ supports_float4=False
72
+ has_local=False
73
+ has_shared=False
74
+
75
+ def render(self, name:str, uops:UOpGraph) -> str:
76
+ # all llvm stuff goes into a module
77
+ module = ir.Module(name=__file__)
78
+
79
+ # extract global buffers (NOTE: this isn't right if DEFINE_GLOBAL is out of order)
80
+ buf_to_dtype = {u.arg:u.dtype for u in uops if u.uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}}
81
+ buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
82
+
83
+ # create llvm function
84
+ func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values() if dtype is not None]
85
+ func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if isinstance(dt, PtrDType) else x for x,dt in func_dtypes]), name=name)
86
+ for a in func.args:
87
+ if a.type.is_pointer: a.add_attribute("noalias")
88
+
89
+ # add the function attribute "no-nans-fp-math"="true", which informs llvm that it allowed to use vectorization optimizations
90
+ func.attributes._known = func.attributes._known.union(frozenset(['"no-nans-fp-math"="true"']))
91
+ func.attributes.add('"no-nans-fp-math"="true"')
92
+
93
+ bb = [ir.IRBuilder(func.append_basic_block("entry"))]
94
+ loop_blocks: List = []
95
+ reduce_phis: List = []
96
+ # TODO: newvar probably shouldn't be optional
97
+ lvars: Dict[Optional[UOp], Any] = {} # this Any is an llvm type
98
+
99
+ for bufname,dtype in buf_to_dtype.items():
100
+ if not isinstance(dtype, PtrDType) and dtype == dtypes.int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
101
+
102
+ for u in uops:
103
+ uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
104
+ if uop is UOps.STORE:
105
+ element = cast(bb, lvars[vin[2]], vin[2].dtype, vin[0].dtype)
106
+ if len(vin) > 3:
107
+ with bb[-1].if_then(lvars[vin[3]]):
108
+ bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
126
109
  else:
127
- val = ir.Constant(dtype_to_llvm_dtype[newvar.dtype], value if args.valid.min == 1 else invalid_value)
128
- # TODO: this is a hack. it shouldn't be const that signals this
129
- reduce_phis.append(newvar)
110
+ bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
111
+ elif uop is UOps.ENDRANGE:
112
+ loop_entry_bb, phis = loop_blocks.pop()
113
+ idx_p1 = bb[-1].add(lvars[vin[0]], ir.Constant(ir.IntType(32), 1))
114
+ lvars[vin[0]].add_incoming(idx_p1, bb[-1].block)
115
+ for n,phi in phis: phi.add_incoming(lvars[n], bb[-1].block)
116
+ bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
117
+ bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].vin[1]]), loop_entry_bb, bb[-1].block)
130
118
  else:
131
- idx = args.idx.render(render_llvm, bb[-1])
132
- if args.valid.min == 0:
133
- aug_idx = bb[-1].select(valid, idx, int_const(0))
134
- val = bb[-1].select(valid, bb[-1].load(bb[-1].gep(func.args[buf_index[args.name]], [aug_idx], inbounds=True)), ir.Constant(dtype_to_llvm_dtype[args.memory_dtype], args.invalid_value))
135
- else:
136
- val = bb[-1].load(bb[-1].gep(func.args[buf_index[args.name]], [idx], inbounds=True))
137
- val = cast(bb, val, args.memory_dtype, newvar.dtype)
138
- lvars[newvar] = val
139
- if uop == UOps.STORE:
140
- assert args.valid.min == 1 and isinstance(args, MemOp), "store must be valid and to memory"
141
- idx = args.idx.render(render_llvm, bb[-1])
142
- element = cast(bb, lvars[vin[0]], vin[0].dtype, args.memory_dtype)
143
- bb[-1].store(element, bb[-1].gep(func.args[buf_index[args.name]], [idx], inbounds=True))
144
- if uop == UOps.ALU:
145
- lvars[newvar] = code_for_op[args](bb[-1], *[lvars[x] for x in vin])
146
-
147
- bb[-1].ret_void()
148
- return str(module), None, None
119
+ assert dtype is not None, f"None dtype for uop {uop}"
120
+ if uop is UOps.RANGE:
121
+ bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}")))
122
+ bb[-2].branch(bb[-1].block)
123
+
124
+ phis = []
125
+ for rp in reduce_phis:
126
+ incoming = lvars[rp]
127
+ lvars[rp] = bb[-1].phi(dtype_to_llvm_dtype[rp.dtype])
128
+ lvars[rp].add_incoming(incoming, bb[-2].block)
129
+ phis.append((rp, lvars[rp]))
130
+
131
+ lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}")
132
+ lvars[u].add_incoming(lvars[vin[0]], bb[-2].block)
133
+ loop_blocks.append((bb[-1].block, phis))
134
+ elif uop is UOps.DEFINE_ACC:
135
+ lvars[u] = const(args[0], dtype)
136
+ reduce_phis.append(u)
137
+ elif uop is UOps.LOAD:
138
+ if len(vin) > 2:
139
+ aug_idx = bb[-1].select(lvars[vin[2]], lvars[vin[1]], ir.Constant(ir.IntType(32), 0))
140
+ val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [aug_idx], inbounds=True))
141
+ val = bb[-1].select(lvars[vin[2]], val, lvars[vin[3]])
142
+ else:
143
+ val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
144
+ lvars[u] = val
145
+ elif uop is UOps.PHI:
146
+ lvars[u] = lvars[vin[1]]
147
+ # PHI UOps can link to other PHI Uops, backtrace this to DEFINE_ACC
148
+ backward = vin[0]
149
+ while backward.uop is UOps.PHI: backward = backward.vin[0]
150
+ lvars[backward] = lvars[u]
151
+ elif uop is UOps.ALU:
152
+ lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else vin[0].dtype)
153
+ elif uop in {UOps.CAST, UOps.BITCAST}: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype, bitcast=uop is UOps.BITCAST)
154
+ elif uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
155
+ elif uop is UOps.SPECIAL: lvars[u] = lvars[args.expr]
156
+ elif uop is UOps.CONST: lvars[u] = const(args, dtype)
157
+ else: raise RuntimeError(f"failed to render {uop}")
158
+
159
+ bb[-1].ret_void()
160
+ return str(module)