tinygrad-0.8.0-py3-none-any.whl → tinygrad-0.9.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. tinygrad/__init__.py +6 -6
  2. tinygrad/codegen/__init__.py +0 -0
  3. tinygrad/codegen/kernel.py +253 -225
  4. tinygrad/codegen/linearizer.py +398 -436
  5. tinygrad/codegen/uops.py +451 -0
  6. tinygrad/device.py +268 -274
  7. tinygrad/dtype.py +56 -40
  8. tinygrad/engine/__init__.py +0 -0
  9. tinygrad/engine/graph.py +100 -0
  10. tinygrad/engine/jit.py +198 -0
  11. tinygrad/engine/realize.py +192 -0
  12. tinygrad/engine/schedule.py +370 -0
  13. tinygrad/engine/search.py +199 -0
  14. tinygrad/{mlops.py → function.py} +40 -32
  15. tinygrad/helpers.py +144 -46
  16. tinygrad/lazy.py +143 -242
  17. tinygrad/multi.py +173 -0
  18. tinygrad/nn/__init__.py +180 -9
  19. tinygrad/nn/datasets.py +8 -0
  20. tinygrad/nn/optim.py +106 -28
  21. tinygrad/nn/state.py +87 -19
  22. tinygrad/ops.py +104 -45
  23. tinygrad/renderer/__init__.py +65 -0
  24. tinygrad/renderer/assembly.py +269 -0
  25. tinygrad/renderer/cstyle.py +308 -210
  26. tinygrad/renderer/llvmir.py +119 -124
  27. tinygrad/runtime/__init__.py +0 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +13403 -0
  29. tinygrad/runtime/autogen/comgr.py +891 -0
  30. tinygrad/runtime/autogen/cuda.py +5923 -0
  31. tinygrad/runtime/autogen/hip.py +5909 -0
  32. tinygrad/runtime/autogen/hsa.py +5893 -0
  33. tinygrad/runtime/autogen/io_uring.py +1486 -0
  34. tinygrad/runtime/autogen/kfd.py +812 -0
  35. tinygrad/runtime/autogen/nv_gpu.py +33597 -0
  36. tinygrad/runtime/autogen/opencl.py +1795 -0
  37. tinygrad/runtime/driver/__init__.py +0 -0
  38. tinygrad/runtime/driver/hip_comgr.py +56 -0
  39. tinygrad/runtime/graph/__init__.py +0 -0
  40. tinygrad/runtime/graph/clang.py +39 -0
  41. tinygrad/runtime/graph/cuda.py +59 -54
  42. tinygrad/runtime/graph/hcq.py +187 -0
  43. tinygrad/runtime/graph/metal.py +37 -41
  44. tinygrad/runtime/ops_amd.py +550 -0
  45. tinygrad/runtime/ops_clang.py +16 -14
  46. tinygrad/runtime/ops_cuda.py +129 -37
  47. tinygrad/runtime/ops_disk.py +111 -43
  48. tinygrad/runtime/ops_gpu.py +52 -50
  49. tinygrad/runtime/ops_llvm.py +36 -56
  50. tinygrad/runtime/ops_metal.py +41 -24
  51. tinygrad/runtime/ops_npy.py +9 -0
  52. tinygrad/runtime/ops_nv.py +625 -0
  53. tinygrad/runtime/ops_python.py +208 -0
  54. tinygrad/shape/__init__.py +0 -0
  55. tinygrad/shape/shapetracker.py +46 -107
  56. tinygrad/shape/symbolic.py +99 -98
  57. tinygrad/shape/view.py +162 -45
  58. tinygrad/tensor.py +2492 -483
  59. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +1 -1
  60. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +31 -13
  61. tinygrad-0.9.1.dist-info/RECORD +63 -0
  62. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
  63. tinygrad/features/image.py +0 -93
  64. tinygrad/features/multi.py +0 -103
  65. tinygrad/features/search.py +0 -160
  66. tinygrad/graph.py +0 -106
  67. tinygrad/jit.py +0 -152
  68. tinygrad/realize.py +0 -50
  69. tinygrad/runtime/graph/hip.py +0 -24
  70. tinygrad/runtime/ops_cpu.py +0 -45
  71. tinygrad/runtime/ops_hip.py +0 -97
  72. tinygrad/runtime/ops_torch.py +0 -49
  73. tinygrad-0.8.0.dist-info/RECORD +0 -41
  74. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
tinygrad/renderer/llvmir.py +119 -124
@@ -1,35 +1,31 @@
1
1
  from typing import Final, Dict, Callable, Any, List, Optional
2
2
  from llvmlite import ir
3
- from tinygrad.codegen.linearizer import UOps, UOp
4
3
  from tinygrad.dtype import DType, PtrDType, dtypes
5
4
  from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
5
+ from tinygrad.codegen.uops import UOps, UOp, UOpGraph
6
+ from tinygrad.renderer import Renderer
6
7
 
7
8
  MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf
8
9
 
9
10
  def is_bool_or_unsigned(dtype: DType): return dtype == dtypes.bool or dtypes.is_unsigned(dtype)
10
11
 
11
12
  code_for_op: Final[Dict[Op, Callable]] = {
12
- UnaryOps.NEG: lambda builder, x, var_dtype: builder.xor(x, ir.Constant(ir.IntType(1), 1)) if var_dtype == dtypes.bool else builder.neg(x) if dtypes.is_int(var_dtype) else builder.fneg(x, flags=MFLAGS), # noqa: E501
13
- UnaryOps.EXP2: lambda builder, x, var_dtype: builder.call(builder._block.module.declare_intrinsic('llvm.exp2', [x.type]), [x], fastmath=MFLAGS),
14
- UnaryOps.LOG2: lambda builder, x, var_dtype: builder.call(builder._block.module.declare_intrinsic('llvm.log2', [x.type]), [x], fastmath=MFLAGS),
15
- UnaryOps.SIN: lambda builder, x, var_dtype: builder.call(builder._block.module.declare_intrinsic('llvm.sin', [x.type]), [x], fastmath=MFLAGS),
16
- UnaryOps.SQRT: lambda builder, x, var_dtype: builder.call(builder._block.module.declare_intrinsic('llvm.sqrt', [x.type]), [x], fastmath=MFLAGS),
17
- BinaryOps.ADD: lambda builder, x, y, var_dtype: builder.or_(x, y) if var_dtype == dtypes.bool else builder.add(x, y) if dtypes.is_int(var_dtype) else builder.fadd(x, y, flags=MFLAGS), # noqa: E501
18
- BinaryOps.SUB: lambda builder, x, y, var_dtype: builder.sub(x, y) if dtypes.is_int(var_dtype) else builder.fsub(x, y, flags=MFLAGS),
19
- BinaryOps.MUL: lambda builder, x, y, var_dtype: # TOOD should we use umul_with_overflow?
20
- builder.mul(x, y) if is_bool_or_unsigned(var_dtype) or dtypes.is_int(var_dtype) else builder.fmul(x, y, flags=MFLAGS),
21
- BinaryOps.DIV: lambda builder, x, y, var_dtype:
22
- builder.udiv(x, y) if is_bool_or_unsigned(var_dtype) else builder.sdiv(x, y) if dtypes.is_int(var_dtype) else builder.fdiv(x, y, flags=MFLAGS),
23
- BinaryOps.CMPLT: lambda builder, x, y, var_dtype: builder.icmp_unsigned("<", x, y) if is_bool_or_unsigned(var_dtype) else builder.icmp_signed("<", x, y) if dtypes.is_int(var_dtype) else builder.fcmp_unordered("<", x, y, flags=MFLAGS), # noqa: E501
24
- BinaryOps.CMPEQ: lambda builder, x, y, var_dtype: builder.icmp_unsigned("==", x, y) if is_bool_or_unsigned(var_dtype) else builder.icmp_signed("==", x, y) if dtypes.is_int(var_dtype) else builder.fcmp_unordered("==", x, y, flags=MFLAGS), # noqa: E501
25
- BinaryOps.MAX: lambda builder, x, y, var_dtype: builder.select(builder.icmp_unsigned(">", x, y) if is_bool_or_unsigned(var_dtype) else builder.icmp_signed(">", x, y) if dtypes.is_int(var_dtype) else builder.fcmp_unordered(">", x, y, flags=MFLAGS), x, y), # noqa: E501
26
- BinaryOps.MOD: lambda builder, x, y, var_dtype:
27
- builder.urem(x, y) if is_bool_or_unsigned(var_dtype) else builder.srem(x, y) if dtypes.is_int(var_dtype) else builder.frem(x, y),
28
- BinaryOps.XOR: lambda builder, x, y, var_dtype: builder.xor(x, y),
29
- TernaryOps.MULACC: lambda builder, x, y, z, var_dtype: builder.fadd(builder.fmul(x, y, flags=MFLAGS), z, flags=MFLAGS) \
30
- if dtypes.is_float(var_dtype) else builder.add(builder.mul(x, y), z),
31
- TernaryOps.WHERE: lambda builder, x, y, z, var_dtype: builder.select(x, y, z),
32
- }
13
+ UnaryOps.NEG: lambda builder, x, dtype: builder.neg(x) if dtypes.is_int(dtype) else \
14
+ (builder.not_(x) if dtype == dtypes.bool else builder.fneg(x, flags=MFLAGS)),
15
+ UnaryOps.EXP2: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.exp2', [x.type]), [x], fastmath=MFLAGS),
16
+ UnaryOps.LOG2: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.log2', [x.type]), [x], fastmath=MFLAGS),
17
+ UnaryOps.RECIP: lambda builder, x, dtype: builder.fdiv(const(1, dtype), x, flags=MFLAGS),
18
+ UnaryOps.SIN: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.sin', [x.type]), [x], fastmath=MFLAGS),
19
+ UnaryOps.SQRT: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.sqrt', [x.type]), [x], fastmath=MFLAGS),
20
+ BinaryOps.ADD: lambda builder, x, y, dtype: builder.or_(x, y) if dtype == dtypes.bool else builder.add(x, y) if dtypes.is_int(dtype) else builder.fadd(x, y, flags=MFLAGS), # noqa: E501
21
+ BinaryOps.MUL: lambda builder, x, y, dtype: builder.mul(x, y) if is_bool_or_unsigned(dtype) or dtypes.is_int(dtype) else builder.fmul(x, y, flags=MFLAGS), # noqa: E501
22
+ BinaryOps.IDIV: lambda builder, x, y, dtype: builder.udiv(x, y) if is_bool_or_unsigned(dtype) else builder.sdiv(x, y),
23
+ BinaryOps.CMPLT: lambda builder, x, y, dtype: builder.icmp_unsigned("<", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("<", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("<", x, y, flags=MFLAGS), # noqa: E501
24
+ BinaryOps.CMPNE: lambda builder, x, y, dtype: builder.icmp_unsigned("!=", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("!=", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("!=", x, y, flags=MFLAGS), # noqa: E501
25
+ BinaryOps.MAX: lambda builder, x, y, dtype: builder.select(builder.icmp_unsigned(">", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed(">", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered(">", x, y, flags=MFLAGS), x, y), # noqa: E501
26
+ BinaryOps.MOD: lambda builder, x, y, dtype: builder.urem(x, y) if is_bool_or_unsigned(dtype) else builder.srem(x, y) if dtypes.is_int(dtype) else builder.frem(x, y), # noqa: E501
27
+ BinaryOps.XOR: lambda builder, x, y, dtype: builder.xor(x, y),
28
+ TernaryOps.WHERE: lambda builder, x, y, z, dtype: builder.select(x, y, z)}
33
29
 
34
30
  dtype_to_llvm_dtype = { dtypes.bool:ir.IntType(1), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.int16:ir.IntType(16),
35
31
  dtypes.uint16:ir.IntType(16), dtypes.int32:ir.IntType(32), dtypes.uint32:ir.IntType(32), dtypes.int64:ir.IntType(64), dtypes.uint64:ir.IntType(64),
@@ -37,7 +33,8 @@ dtype_to_llvm_dtype = { dtypes.bool:ir.IntType(1), dtypes.int8:ir.IntType(8), dt
37
33
 
38
34
  def cast(bb, val, input_type, output_type, bitcast=False):
39
35
  if input_type == output_type: return val
40
- if bitcast: return bb[-1].bitcast(val, dtype_to_llvm_dtype[output_type])
36
+ llvm_type = dtype_to_llvm_dtype[output_type]
37
+ if bitcast: return bb[-1].bitcast(val, llvm_type)
41
38
 
42
39
  if input_type == dtypes.bfloat16:
43
40
  val = bb[-1].bitcast(bb[-1].shl(bb[-1].sext(val, ir.IntType(32)), ir.Constant(ir.IntType(32), 16)),val, ir.FloatType())
@@ -48,118 +45,116 @@ def cast(bb, val, input_type, output_type, bitcast=False):
48
45
 
49
46
  if dtypes.is_float(input_type):
50
47
  if dtypes.is_float(output_type):
51
- if output_type.itemsize > input_type.itemsize: return bb[-1].fpext(val, dtype_to_llvm_dtype[output_type])
52
- return bb[-1].fptrunc(val, dtype_to_llvm_dtype[output_type])
53
- if dtypes.is_int(output_type):
54
- if dtypes.is_unsigned(output_type): return bb[-1].fptoui(val, dtype_to_llvm_dtype[output_type])
55
- return bb[-1].fptosi(val, dtype_to_llvm_dtype[output_type])
48
+ return bb[-1].fpext(val, llvm_type) if output_type.itemsize > input_type.itemsize else bb[-1].fptrunc(val, llvm_type)
49
+ if dtypes.is_int(output_type): return bb[-1].fptoui(val, llvm_type) if dtypes.is_unsigned(output_type) else bb[-1].fptosi(val, llvm_type)
56
50
  if output_type == dtypes.bool: return bb[-1].fcmp_unordered('!=', cast(bb, val, input_type, dtypes.float32), ir.Constant(ir.FloatType(), 0))
57
51
 
58
52
  if dtypes.is_unsigned(input_type) or input_type == dtypes.bool:
59
53
  if output_type == dtypes.float16: return bb[-1].fptrunc(bb[-1].uitofp(val, ir.FloatType()), ir.HalfType())
60
54
  if dtypes.is_float(output_type): return bb[-1].uitofp(val, dtype_to_llvm_dtype[output_type])
61
- if dtypes.is_int(output_type):
62
- if input_type.itemsize > output_type.itemsize: return bb[-1].trunc(val, dtype_to_llvm_dtype[output_type])
63
- return bb[-1].zext(val, dtype_to_llvm_dtype[output_type])
55
+ if dtypes.is_int(output_type): return bb[-1].trunc(val, llvm_type) if input_type.itemsize > output_type.itemsize else bb[-1].zext(val, llvm_type)
64
56
  if output_type == dtypes.bool: return bb[-1].icmp_unsigned('!=', val, ir.Constant(val.type, 0))
65
57
 
66
58
  if dtypes.is_int(input_type):
67
59
  if output_type == dtypes.float16: return bb[-1].fptrunc(bb[-1].sitofp(val, ir.FloatType()), ir.HalfType())
68
- if dtypes.is_float(output_type): return bb[-1].sitofp(val, dtype_to_llvm_dtype[output_type])
69
- if dtypes.is_int(output_type):
70
- if input_type.itemsize > output_type.itemsize: return bb[-1].trunc(val, dtype_to_llvm_dtype[output_type])
71
- return bb[-1].sext(val, dtype_to_llvm_dtype[output_type])
60
+ if dtypes.is_float(output_type): return bb[-1].sitofp(val, llvm_type)
61
+ if dtypes.is_int(output_type): return bb[-1].trunc(val, llvm_type) if input_type.itemsize > output_type.itemsize else bb[-1].sext(val, llvm_type)
72
62
  if output_type == dtypes.bool: return bb[-1].icmp_signed('!=', val, ir.Constant(val.type, 0))
73
63
 
74
64
  raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")
75
65
 
76
- def const(args, dtype):
77
- # TODO: remove int from int(args) once const args conform with dtype
78
- return ir.Constant(dtype_to_llvm_dtype[dtype], int(args) if dtypes.is_int(dtype) else bool(args) if dtype == dtypes.bool else args)
79
-
80
- def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> str:
81
- # all llvm stuff goes into a module
82
- module = ir.Module(name=__file__)
83
-
84
- # extract global buffers
85
- buf_to_dtype = {u.arg:u.dtype for u in uops if u.uop == UOps.DEFINE_GLOBAL}
86
- buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
87
-
88
- # create llvm function
89
- func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values() if dtype is not None]
90
- func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if isinstance(dt, PtrDType) else x for x,dt in func_dtypes]), name=function_name) # noqa: E501
91
- for a in func.args:
92
- if a.type.is_pointer: a.add_attribute("noalias")
93
-
94
- # add the function attribute "no-nans-fp-math"="true", which informs llvm that it allowed to use vectorization optimizations
95
- func.attributes._known = func.attributes._known.union(frozenset(['"no-nans-fp-math"="true"']))
96
- func.attributes.add('"no-nans-fp-math"="true"')
97
-
98
- bb = [ir.IRBuilder(func.append_basic_block("entry"))]
99
- loop_blocks: List = []
100
- reduce_phis: List = []
101
- # TODO: newvar probably shouldn't be optional
102
- lvars: Dict[Optional[UOp], Any] = {} # this Any is an llvm type
103
-
104
- for bufname,dtype in buf_to_dtype.items():
105
- if not isinstance(dtype, PtrDType) and dtype == dtypes.int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
106
-
107
- for u in uops:
108
- uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
109
- if uop == UOps.LOOP:
110
- bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}")))
111
- bb[-2].branch(bb[-1]._block)
112
-
113
- phis = []
114
- for rp in reduce_phis:
115
- incoming = lvars[rp]
116
- lvars[rp] = bb[-1].phi(dtype_to_llvm_dtype[rp.dtype])
117
- lvars[rp].add_incoming(incoming, bb[-2]._block)
118
- phis.append((rp, lvars[rp]))
119
-
120
- lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}")
121
- lvars[u].add_incoming(lvars[vin[0]], bb[-2]._block)
122
- loop_blocks.append((bb[-1], phis))
123
- if uop == UOps.END:
124
- block, phis = loop_blocks.pop()
125
- idx_p1 = bb[-1].add(lvars[vin[0]], ir.Constant(ir.IntType(32), 1))
126
- lvars[vin[0]].add_incoming(idx_p1, bb[-1]._block)
127
- for n,phi in phis: phi.add_incoming(lvars[n], bb[-1]._block)
128
- bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
129
- bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].vin[1]]), block._block, bb[-1]._block)
130
- if uop == UOps.DEFINE_GLOBAL: lvars[u] = func.args[buf_index[args]]
131
- if uop == UOps.DEFINE_ACC:
132
- lvars[u] = const(args, dtype)
133
- reduce_phis.append(u)
134
- if uop == UOps.SPECIAL: lvars[u] = lvars[args.expr]
135
- if uop == UOps.CONST: lvars[u] = const(args, dtype)
136
- if uop == UOps.LOAD:
137
- assert dtype is not None
138
- if len(vin) > 2:
139
- gate = bb[-1].trunc(lvars[vin[2]], ir.IntType(1))
140
- aug_idx = bb[-1].select(gate, lvars[vin[1]], ir.Constant(ir.IntType(32), 0))
141
- val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [aug_idx], inbounds=True))
142
- val = cast(bb, val, vin[0].dtype, dtype)
143
- val = bb[-1].select(gate, val, lvars[vin[3]])
66
+ def const(args, dtype): return ir.Constant(dtype_to_llvm_dtype[dtype], args)
67
+
68
+ class LLVMRenderer(Renderer):
69
+ device = "LLVM"
70
+ supports_float4 = False
71
+ has_local = False
72
+ has_shared = False
73
+ global_max = None
74
+
75
+ def render(self, name:str, uops:UOpGraph) -> str:
76
+ # all llvm stuff goes into a module
77
+ module = ir.Module(name=__file__)
78
+
79
+ # extract global buffers (NOTE: this isn't right if DEFINE_GLOBAL is out of order)
80
+ buf_to_dtype = {u.arg:u.dtype for u in uops if u.op in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}}
81
+ buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
82
+
83
+ # create llvm function
84
+ func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values() if dtype is not None]
85
+ func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if isinstance(dt, PtrDType) else x for x,dt in func_dtypes]), name=name)
86
+ for a in func.args:
87
+ if a.type.is_pointer: a.add_attribute("noalias")
88
+
89
+ # add the function attribute "no-nans-fp-math"="true", which informs llvm that it allowed to use vectorization optimizations
90
+ func.attributes._known = func.attributes._known.union(frozenset(['"no-nans-fp-math"="true"']))
91
+ func.attributes.add('"no-nans-fp-math"="true"')
92
+
93
+ bb = [ir.IRBuilder(func.append_basic_block("entry"))]
94
+ loop_blocks: List = []
95
+ reduce_phis: List = []
96
+ # TODO: newvar probably shouldn't be optional
97
+ lvars: Dict[Optional[UOp], Any] = {} # this Any is an llvm type
98
+
99
+ for bufname,dtype in buf_to_dtype.items():
100
+ if not isinstance(dtype, PtrDType) and dtype == dtypes.int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
101
+
102
+ for u in uops:
103
+ uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
104
+ if uop is UOps.STORE:
105
+ element = cast(bb, lvars[src[2]], src[2].dtype, src[0].dtype)
106
+ if len(src) > 3:
107
+ with bb[-1].if_then(lvars[src[3]]):
108
+ bb[-1].store(element, bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
109
+ else:
110
+ bb[-1].store(element, bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
111
+ elif uop is UOps.ENDRANGE:
112
+ loop_entry_bb, phis = loop_blocks.pop()
113
+ idx_p1 = bb[-1].add(lvars[src[0]], ir.Constant(ir.IntType(32), 1))
114
+ lvars[src[0]].add_incoming(idx_p1, bb[-1].block)
115
+ for n,phi in phis: phi.add_incoming(lvars[n], bb[-1].block)
116
+ bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
117
+ bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[src[0].src[1]]), loop_entry_bb, bb[-1].block)
144
118
  else:
145
- val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
146
- val = cast(bb, val, vin[0].dtype, dtype)
147
- lvars[u] = val
148
- if uop == UOps.PHI:
149
- lvars[u] = lvars[vin[1]]
150
- # PHI UOps can link to other PHI Uops, backtrace this to DEFINE_ACC
151
- backward = vin[0]
152
- while backward.uop == UOps.PHI: backward = backward.vin[0]
153
- lvars[backward] = lvars[u]
154
- if uop == UOps.STORE:
155
- element = cast(bb, lvars[vin[2]], vin[2].dtype, vin[0].dtype)
156
- def store_op(): bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
157
- if len(vin) > 3:
158
- with bb[-1].if_then(bb[-1].trunc(lvars[vin[3]], ir.IntType(1))): store_op()
159
- else: store_op()
160
- if uop == UOps.ALU:
161
- lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin] + [dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else vin[0].dtype])
162
- if uop == UOps.CAST: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype, bitcast=isinstance(args, tuple) and args[1])
163
-
164
- bb[-1].ret_void()
165
- return str(module)
119
+ assert dtype is not None, f"None dtype for uop {uop}"
120
+ if uop is UOps.RANGE:
121
+ bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}")))
122
+ bb[-2].branch(bb[-1].block)
123
+
124
+ phis = []
125
+ for rp in reduce_phis:
126
+ incoming = lvars[rp]
127
+ lvars[rp] = bb[-1].phi(dtype_to_llvm_dtype[rp.dtype])
128
+ lvars[rp].add_incoming(incoming, bb[-2].block)
129
+ phis.append((rp, lvars[rp]))
130
+
131
+ lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}")
132
+ lvars[u].add_incoming(lvars[src[0]], bb[-2].block)
133
+ loop_blocks.append((bb[-1].block, phis))
134
+ elif uop is UOps.DEFINE_ACC:
135
+ lvars[u] = const(src[0].arg, dtype)
136
+ reduce_phis.append(u)
137
+ elif uop is UOps.LOAD:
138
+ if len(src) > 2:
139
+ aug_idx = bb[-1].select(lvars[src[2]], lvars[src[1]], ir.Constant(ir.IntType(32), 0))
140
+ val = bb[-1].load(bb[-1].gep(lvars[src[0]], [aug_idx], inbounds=True))
141
+ val = bb[-1].select(lvars[src[2]], val, lvars[src[3]])
142
+ else:
143
+ val = bb[-1].load(bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
144
+ lvars[u] = val
145
+ elif uop is UOps.PHI:
146
+ lvars[u] = lvars[src[1]]
147
+ # PHI UOps can link to other PHI Uops, backtrace this to DEFINE_ACC
148
+ backward = src[0]
149
+ while backward.op is UOps.PHI: backward = backward.src[0]
150
+ lvars[backward] = lvars[u]
151
+ elif uop is UOps.ALU:
152
+ lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in src], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPNE) else src[0].dtype)
153
+ elif uop in {UOps.CAST, UOps.BITCAST}: lvars[u] = cast(bb, lvars[src[0]], src[0].dtype, dtype, bitcast=uop is UOps.BITCAST)
154
+ elif uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
155
+ elif uop is UOps.SPECIAL: lvars[u] = lvars[args.expr]
156
+ elif uop is UOps.CONST: lvars[u] = const(args, dtype)
157
+ else: raise RuntimeError(f"failed to render {uop}")
158
+
159
+ bb[-1].ret_void()
160
+ return str(module)
File without changes