tinygrad 0.7.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +6 -0
- tinygrad/codegen/kernel.py +572 -83
- tinygrad/codegen/linearizer.py +415 -395
- tinygrad/codegen/uops.py +415 -0
- tinygrad/device.py +183 -0
- tinygrad/dtype.py +113 -0
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +195 -0
- tinygrad/engine/realize.py +191 -0
- tinygrad/engine/schedule.py +362 -0
- tinygrad/engine/search.py +196 -0
- tinygrad/{mlops.py → function.py} +76 -55
- tinygrad/helpers.py +196 -89
- tinygrad/lazy.py +210 -371
- tinygrad/multi.py +169 -0
- tinygrad/nn/__init__.py +202 -22
- tinygrad/nn/datasets.py +7 -0
- tinygrad/nn/optim.py +112 -32
- tinygrad/nn/state.py +136 -39
- tinygrad/ops.py +119 -202
- tinygrad/renderer/__init__.py +61 -0
- tinygrad/renderer/assembly.py +276 -0
- tinygrad/renderer/cstyle.py +353 -166
- tinygrad/renderer/llvmir.py +150 -138
- tinygrad/runtime/autogen/amd_gpu.py +1900 -0
- tinygrad/runtime/autogen/comgr.py +865 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5761 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33328 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/hip_comgr.py +47 -0
- tinygrad/runtime/driver/hsa.py +143 -0
- tinygrad/runtime/graph/clang.py +38 -0
- tinygrad/runtime/graph/cuda.py +81 -0
- tinygrad/runtime/graph/hcq.py +143 -0
- tinygrad/runtime/graph/hsa.py +171 -0
- tinygrad/runtime/graph/metal.py +75 -0
- tinygrad/runtime/ops_amd.py +564 -0
- tinygrad/runtime/ops_clang.py +24 -77
- tinygrad/runtime/ops_cuda.py +175 -89
- tinygrad/runtime/ops_disk.py +56 -33
- tinygrad/runtime/ops_gpu.py +92 -95
- tinygrad/runtime/ops_hsa.py +278 -0
- tinygrad/runtime/ops_llvm.py +39 -60
- tinygrad/runtime/ops_metal.py +92 -74
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +630 -0
- tinygrad/runtime/ops_python.py +204 -0
- tinygrad/shape/shapetracker.py +86 -254
- tinygrad/shape/symbolic.py +166 -141
- tinygrad/shape/view.py +296 -0
- tinygrad/tensor.py +2619 -448
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
- tinygrad-0.9.0.dist-info/METADATA +227 -0
- tinygrad-0.9.0.dist-info/RECORD +60 -0
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/assembly.py +0 -190
- tinygrad/codegen/optimizer.py +0 -379
- tinygrad/codegen/search.py +0 -72
- tinygrad/graph.py +0 -83
- tinygrad/jit.py +0 -57
- tinygrad/nn/image.py +0 -100
- tinygrad/renderer/assembly_arm64.py +0 -169
- tinygrad/renderer/assembly_ptx.py +0 -98
- tinygrad/renderer/wgsl.py +0 -53
- tinygrad/runtime/lib.py +0 -113
- tinygrad/runtime/ops_cpu.py +0 -51
- tinygrad/runtime/ops_hip.py +0 -82
- tinygrad/runtime/ops_shm.py +0 -29
- tinygrad/runtime/ops_torch.py +0 -30
- tinygrad/runtime/ops_webgpu.py +0 -45
- tinygrad-0.7.0.dist-info/METADATA +0 -212
- tinygrad-0.7.0.dist-info/RECORD +0 -40
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
tinygrad/renderer/llvmir.py
CHANGED
@@ -1,148 +1,160 @@
|
|
1
|
-
from typing import Final, Dict, Callable, Any, List, Optional
|
2
|
-
import
|
3
|
-
from
|
4
|
-
from tinygrad.
|
5
|
-
from tinygrad.helpers import dtypes
|
1
|
+
from typing import Final, Dict, Callable, Any, List, Optional
|
2
|
+
from llvmlite import ir
|
3
|
+
from tinygrad.codegen.linearizer import UOps, UOp
|
4
|
+
from tinygrad.dtype import DType, PtrDType, dtypes
|
6
5
|
from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
|
6
|
+
from tinygrad.codegen.uops import UOpGraph
|
7
|
+
from tinygrad.renderer import Renderer
|
7
8
|
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
NumNode: lambda self,ops,ctx: int_const(self.b),
|
12
|
-
MulNode: lambda self,ops,ctx: ctx.mul(self.a.render(ops,ctx), int_const(self.b)),
|
13
|
-
DivNode: lambda self,ops,ctx: ctx.sdiv(self.a.render(ops,ctx), int_const(self.b)),
|
14
|
-
ModNode: lambda self,ops,ctx: ctx.srem(self.a.render(ops,ctx), int_const(self.b)),
|
15
|
-
LtNode: lambda self,ops,ctx: ctx.icmp_signed("<", self.a.render(ops,ctx), int_const(self.b)),
|
16
|
-
SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.add(a,b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
|
17
|
-
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.and_(a,b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx))
|
18
|
-
}
|
9
|
+
# Fast-math flags attached to floating point instructions: everything fast-math enables except nnan and ninf.
MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc')

def is_bool_or_unsigned(dtype: DType):
  """Return True for bool and unsigned integer dtypes, i.e. the dtypes that get unsigned LLVM integer ops."""
  if dtype == dtypes.bool: return True
  return dtypes.is_unsigned(dtype)
|
19
12
|
|
20
13
|
def _unary_intrinsic(intrinsic_name):
  """Build a code_for_op entry that calls LLVM intrinsic ``intrinsic_name``, overloaded on the operand's type."""
  def emit(builder, x, dtype): return builder.call(builder.module.declare_intrinsic(intrinsic_name, [x.type]), [x], fastmath=MFLAGS)
  return emit

def _neg(builder, x, dtype):
  # integers: arithmetic negation; bool: logical not; floats: fneg with fast-math flags
  if dtypes.is_int(dtype): return builder.neg(x)
  if dtype == dtypes.bool: return builder.not_(x)
  return builder.fneg(x, flags=MFLAGS)

def _compare(builder, op, x, y, dtype):
  """Emit comparison ``op``: unsigned icmp for bool/unsigned, signed icmp for other ints, unordered fcmp for floats."""
  if is_bool_or_unsigned(dtype): return builder.icmp_unsigned(op, x, y)
  if dtypes.is_int(dtype): return builder.icmp_signed(op, x, y)
  return builder.fcmp_unordered(op, x, y, flags=MFLAGS)

def _add(builder, x, y, dtype):
  # bool addition is rendered as logical or
  if dtype == dtypes.bool: return builder.or_(x, y)
  if dtypes.is_int(dtype): return builder.add(x, y)
  return builder.fadd(x, y, flags=MFLAGS)

def _sub(builder, x, y, dtype):
  if dtypes.is_int(dtype): return builder.sub(x, y)
  return builder.fsub(x, y, flags=MFLAGS)

def _mul(builder, x, y, dtype):
  if is_bool_or_unsigned(dtype) or dtypes.is_int(dtype): return builder.mul(x, y)
  return builder.fmul(x, y, flags=MFLAGS)

def _div(builder, x, y, dtype):
  if is_bool_or_unsigned(dtype): return builder.udiv(x, y)
  if dtypes.is_int(dtype): return builder.sdiv(x, y)
  return builder.fdiv(x, y, flags=MFLAGS)

def _max(builder, x, y, dtype):
  # select x when x > y (per-dtype compare), otherwise y
  return builder.select(_compare(builder, ">", x, y, dtype), x, y)

def _mod(builder, x, y, dtype):
  if is_bool_or_unsigned(dtype): return builder.urem(x, y)
  if dtypes.is_int(dtype): return builder.srem(x, y)
  return builder.frem(x, y)

# Op -> emitter(builder, *operands, dtype). The dtype selects between unsigned/signed integer and float instructions.
code_for_op: Final[Dict[Op, Callable]] = {
  UnaryOps.NEG: _neg,
  UnaryOps.EXP2: _unary_intrinsic('llvm.exp2'),
  UnaryOps.LOG2: _unary_intrinsic('llvm.log2'),
  UnaryOps.SIN: _unary_intrinsic('llvm.sin'),
  UnaryOps.SQRT: _unary_intrinsic('llvm.sqrt'),
  BinaryOps.ADD: _add,
  BinaryOps.SUB: _sub,
  BinaryOps.MUL: _mul,
  BinaryOps.DIV: _div,
  BinaryOps.CMPLT: lambda builder, x, y, dtype: _compare(builder, "<", x, y, dtype),
  BinaryOps.CMPEQ: lambda builder, x, y, dtype: _compare(builder, "==", x, y, dtype),
  BinaryOps.MAX: _max,
  BinaryOps.MOD: _mod,
  BinaryOps.XOR: lambda builder, x, y, dtype: builder.xor(x, y),
  TernaryOps.WHERE: lambda builder, x, y, z, dtype: builder.select(x, y, z)}
|
30
|
+
|
31
|
+
# tinygrad dtype -> llvmlite IR type used for scalar values and buffer elements.
# Signed and unsigned integers share one LLVM integer type; signedness is chosen per instruction (see code_for_op / cast).
# bfloat16 has no LLVM float type here: it is carried as its raw 16 bits and converted explicitly in `cast`.
dtype_to_llvm_dtype = {
  dtypes.bool: ir.IntType(1),
  dtypes.int8: ir.IntType(8),
  dtypes.uint8: ir.IntType(8),
  dtypes.int16: ir.IntType(16),
  dtypes.uint16: ir.IntType(16),
  dtypes.int32: ir.IntType(32),
  dtypes.uint32: ir.IntType(32),
  dtypes.int64: ir.IntType(64),
  dtypes.uint64: ir.IntType(64),
  dtypes.float16: ir.HalfType(),
  dtypes.bfloat16: ir.IntType(16),
  dtypes.float32: ir.FloatType(),
  dtypes.float64: ir.DoubleType(),
}
|
34
|
+
|
35
|
+
def cast(bb, val, input_type, output_type, bitcast=False):
  """Emit LLVM IR converting ``val`` from tinygrad dtype ``input_type`` to ``output_type``.

  bb: list of llvmlite IRBuilders; instructions are appended through the last one (``bb[-1]``).
  val: the llvmlite value to convert.
  bitcast: when True, reinterpret the bits of ``val`` as ``output_type`` instead of converting the value.
  Returns the converted llvmlite value (``val`` itself when the dtypes are equal).
  Raises NotImplementedError for dtype pairs that are not handled.
  """
  if input_type == output_type: return val
  llvm_type = dtype_to_llvm_dtype[output_type]
  if bitcast: return bb[-1].bitcast(val, llvm_type)

  # bfloat16 is stored as i16 and holds the top 16 bits of a float32: widen to i32, shift into the high half,
  # then reinterpret as float. (The upper bits are shifted out, so zero- vs sign-extension does not matter.)
  if input_type == dtypes.bfloat16:
    val = bb[-1].bitcast(bb[-1].shl(bb[-1].zext(val, ir.IntType(32)), ir.Constant(ir.IntType(32), 16)), ir.FloatType())
    input_type = dtypes.float32
    if input_type == output_type: return val
  # to bfloat16: go through float32 and keep its top 16 bits (truncation, not round-to-nearest)
  if output_type == dtypes.bfloat16:
    val = cast(bb, val, input_type, dtypes.float32)
    return bb[-1].trunc(bb[-1].lshr(bb[-1].bitcast(val, ir.IntType(32)), ir.Constant(ir.IntType(32), 16)), ir.IntType(16))

  if dtypes.is_float(input_type):
    if dtypes.is_float(output_type):
      return bb[-1].fpext(val, llvm_type) if output_type.itemsize > input_type.itemsize else bb[-1].fptrunc(val, llvm_type)
    if dtypes.is_int(output_type): return bb[-1].fptoui(val, llvm_type) if dtypes.is_unsigned(output_type) else bb[-1].fptosi(val, llvm_type)
    # compare against zero in the value's own type: narrowing to float32 first would turn tiny float64 values into 0
    if output_type == dtypes.bool: return bb[-1].fcmp_unordered('!=', val, ir.Constant(val.type, 0))

  if dtypes.is_unsigned(input_type) or input_type == dtypes.bool:
    if output_type == dtypes.float16: return bb[-1].fptrunc(bb[-1].uitofp(val, ir.FloatType()), ir.HalfType())
    if dtypes.is_float(output_type): return bb[-1].uitofp(val, llvm_type)
    if dtypes.is_int(output_type): return bb[-1].trunc(val, llvm_type) if input_type.itemsize > output_type.itemsize else bb[-1].zext(val, llvm_type)
    if output_type == dtypes.bool: return bb[-1].icmp_unsigned('!=', val, ir.Constant(val.type, 0))

  if dtypes.is_int(input_type):
    if output_type == dtypes.float16: return bb[-1].fptrunc(bb[-1].sitofp(val, ir.FloatType()), ir.HalfType())
    if dtypes.is_float(output_type): return bb[-1].sitofp(val, llvm_type)
    if dtypes.is_int(output_type): return bb[-1].trunc(val, llvm_type) if input_type.itemsize > output_type.itemsize else bb[-1].sext(val, llvm_type)
    if output_type == dtypes.bool: return bb[-1].icmp_signed('!=', val, ir.Constant(val.type, 0))

  raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")
|
68
66
|
|
69
|
-
def
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
for var in args[0][::-1]:
|
112
|
-
if isinstance(var, NumNode): continue
|
113
|
-
block, phis = loop_blocks.pop()
|
114
|
-
idx_p1 = bb[-1].add(lvars[var.expr], int_const(1))
|
115
|
-
lvars[var.expr].add_incoming(idx_p1, bb[-1]._block)
|
116
|
-
for n,phi in phis: phi.add_incoming(lvars[n], bb[-1]._block)
|
117
|
-
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{var.expr}")))
|
118
|
-
bb[-2].cbranch(bb[-2].icmp_unsigned("==", idx_p1, int_const(var.max+1)), bb[-1]._block, block._block)
|
119
|
-
if uop == UOps.LOAD:
|
120
|
-
assert newvar is not None and isinstance(args, (MemOp, ConstOp))
|
121
|
-
valid = args.valid.render(render_llvm, bb[-1])
|
122
|
-
if isinstance(args, ConstOp):
|
123
|
-
value, invalid_value = [int(args.value), int(args.invalid_value)] if dtypes.is_int(newvar.dtype) else ([bool(args.value), bool(args.invalid_value)] if newvar.dtype == dtypes.bool else [args.value, args.invalid_value]) # type: ignore
|
124
|
-
if args.valid.min == 0 and args.valid.max == 1:
|
125
|
-
val = bb[-1].select(valid, ir.Constant(dtype_to_llvm_dtype[newvar.dtype], value), ir.Constant(dtype_to_llvm_dtype[newvar.dtype], invalid_value))
|
67
|
+
def const(args, dtype):
  """Return an llvmlite constant holding the Python value ``args``, typed as tinygrad ``dtype`` maps to in LLVM."""
  llvm_type = dtype_to_llvm_dtype[dtype]
  return ir.Constant(llvm_type, args)
|
68
|
+
|
69
|
+
class LLVMRenderer(Renderer):
  # Renderer that turns a UOpGraph into textual LLVM IR (via llvmlite) for the "LLVM" device.
  device = "LLVM"
  supports_float4=False
  has_local=False
  has_shared=False

  def render(self, name:str, uops:UOpGraph) -> str:
    """Render ``uops`` as a single LLVM function called ``name`` and return the module's IR text.

    Every DEFINE_GLOBAL / DEFINE_VAR becomes a function argument (pointers for PtrDType, scalars otherwise).
    Raises RuntimeError for uop kinds this renderer does not handle.
    """
    # all llvm stuff goes into a module
    module = ir.Module(name=__file__)

    # extract global buffers (NOTE: this isn't right if DEFINE_GLOBAL is out of order)
    buf_to_dtype = {u.arg:u.dtype for u in uops if u.uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}}
    # argument position of each buffer/variable, in first-seen order
    buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}

    # create llvm function
    func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values() if dtype is not None]
    func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if isinstance(dt, PtrDType) else x for x,dt in func_dtypes]), name=name)
    for a in func.args:
      if a.type.is_pointer: a.add_attribute("noalias")

    # add the function attribute "no-nans-fp-math"="true", which tells llvm it is allowed to use vectorization optimizations
    # (llvmlite does not know this attribute, so it is first registered in the private `_known` set)
    func.attributes._known = func.attributes._known.union(frozenset(['"no-nans-fp-math"="true"']))
    func.attributes.add('"no-nans-fp-math"="true"')

    # bb is a stack of builders; bb[-1] is the block currently being filled
    bb = [ir.IRBuilder(func.append_basic_block("entry"))]
    # stack of (loop body block, [(acc uop, phi)]) for every open RANGE
    loop_blocks: List = []
    # DEFINE_ACC uops whose value must flow through a phi in each loop opened after them
    reduce_phis: List = []
    # TODO: newvar probably shouldn't be optional
    lvars: Dict[Optional[UOp], Any] = {} # this Any is an llvm type

    # NOTE(review): the arg is already i32 here, so this sext is i32 -> i32; llvmlite's cast helpers appear to return the
    # operand unchanged when types match — confirm. The entry is keyed by the uop arg, not by a UOp.
    for bufname,dtype in buf_to_dtype.items():
      if not isinstance(dtype, PtrDType) and dtype == dtypes.int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))

    for u in uops:
      uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
      if uop is UOps.STORE:
        # vin = (buffer, index, value[, gate]); the value is cast to the buffer's dtype before storing
        element = cast(bb, lvars[vin[2]], vin[2].dtype, vin[0].dtype)
        if len(vin) > 3:
          # gated store: only executed when the gate is true
          with bb[-1].if_then(lvars[vin[3]]):
            bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
        else:
          bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
      elif uop is UOps.ENDRANGE:
        # close the innermost loop: increment the index, feed the latest accumulator values back into their phis,
        # then branch back to the loop body while idx+1 < end (vin[0] is the RANGE; its vin[1] is the end value)
        # NOTE(review): the bound is only checked here, after the body, so the loop is a do-while and a range whose
        # start is already >= its end still runs once — confirm the linearizer never emits such ranges.
        loop_entry_bb, phis = loop_blocks.pop()
        idx_p1 = bb[-1].add(lvars[vin[0]], ir.Constant(ir.IntType(32), 1))
        lvars[vin[0]].add_incoming(idx_p1, bb[-1].block)
        for n,phi in phis: phi.add_incoming(lvars[n], bb[-1].block)
        bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
        bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].vin[1]]), loop_entry_bb, bb[-1].block)
      else:
        assert dtype is not None, f"None dtype for uop {uop}"
        if uop is UOps.RANGE:
          # open a new loop body block, entered unconditionally from the current block
          bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}")))
          bb[-2].branch(bb[-1].block)

          # every pending accumulator gets a phi in the loop body, seeded with its value from the preceding block
          phis = []
          for rp in reduce_phis:
            incoming = lvars[rp]
            lvars[rp] = bb[-1].phi(dtype_to_llvm_dtype[rp.dtype])
            lvars[rp].add_incoming(incoming, bb[-2].block)
            phis.append((rp, lvars[rp]))

          # i32 loop index phi, starting at the range's start value (vin[0]); ENDRANGE adds the back edge
          lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}")
          lvars[u].add_incoming(lvars[vin[0]], bb[-2].block)
          loop_blocks.append((bb[-1].block, phis))
        elif uop is UOps.DEFINE_ACC:
          # accumulator starts as the constant args[0]; later RANGEs wrap it in phis
          lvars[u] = const(args[0], dtype)
          reduce_phis.append(u)
        elif uop is UOps.LOAD:
          if len(vin) > 2:
            # gated load: vin = (buffer, index, gate, alternative). When the gate is false, read index 0 instead of
            # the (possibly invalid) index, then select the alternative value.
            aug_idx = bb[-1].select(lvars[vin[2]], lvars[vin[1]], ir.Constant(ir.IntType(32), 0))
            val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [aug_idx], inbounds=True))
            val = bb[-1].select(lvars[vin[2]], val, lvars[vin[3]])
          else:
            val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
          lvars[u] = val
        elif uop is UOps.PHI:
          lvars[u] = lvars[vin[1]]
          # PHI UOps can link to other PHI Uops, backtrace this to DEFINE_ACC
          # (rebinding the root accumulator is what ENDRANGE reads when completing the loop phis)
          backward = vin[0]
          while backward.uop is UOps.PHI: backward = backward.vin[0]
          lvars[backward] = lvars[u]
        elif uop is UOps.ALU:
          # comparisons pick their instruction from the operand dtype (vin[0].dtype) rather than the result dtype
          lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else vin[0].dtype)
        elif uop in {UOps.CAST, UOps.BITCAST}: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype, bitcast=uop is UOps.BITCAST)
        elif uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
        # NOTE(review): looks up lvars by args.expr; the only non-UOp keys stored above come from the arg-keyed loop — confirm
        # an entry with this key exists (this renderer sets has_local=False)
        elif uop is UOps.SPECIAL: lvars[u] = lvars[args.expr]
        elif uop is UOps.CONST: lvars[u] = const(args, dtype)
        else: raise RuntimeError(f"failed to render {uop}")

    bb[-1].ret_void()
    return str(module)
|