tinygrad 0.8.0__py3-none-any.whl → 0.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +6 -6
- tinygrad/codegen/__init__.py +0 -0
- tinygrad/codegen/kernel.py +253 -225
- tinygrad/codegen/linearizer.py +398 -436
- tinygrad/codegen/uops.py +451 -0
- tinygrad/device.py +268 -274
- tinygrad/dtype.py +56 -40
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +198 -0
- tinygrad/engine/realize.py +192 -0
- tinygrad/engine/schedule.py +370 -0
- tinygrad/engine/search.py +199 -0
- tinygrad/{mlops.py → function.py} +40 -32
- tinygrad/helpers.py +144 -46
- tinygrad/lazy.py +143 -242
- tinygrad/multi.py +173 -0
- tinygrad/nn/__init__.py +180 -9
- tinygrad/nn/datasets.py +8 -0
- tinygrad/nn/optim.py +106 -28
- tinygrad/nn/state.py +87 -19
- tinygrad/ops.py +104 -45
- tinygrad/renderer/__init__.py +65 -0
- tinygrad/renderer/assembly.py +269 -0
- tinygrad/renderer/cstyle.py +308 -210
- tinygrad/renderer/llvmir.py +119 -124
- tinygrad/runtime/__init__.py +0 -0
- tinygrad/runtime/autogen/amd_gpu.py +13403 -0
- tinygrad/runtime/autogen/comgr.py +891 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5893 -0
- tinygrad/runtime/autogen/io_uring.py +1486 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33597 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/__init__.py +0 -0
- tinygrad/runtime/driver/hip_comgr.py +56 -0
- tinygrad/runtime/graph/__init__.py +0 -0
- tinygrad/runtime/graph/clang.py +39 -0
- tinygrad/runtime/graph/cuda.py +59 -54
- tinygrad/runtime/graph/hcq.py +187 -0
- tinygrad/runtime/graph/metal.py +37 -41
- tinygrad/runtime/ops_amd.py +550 -0
- tinygrad/runtime/ops_clang.py +16 -14
- tinygrad/runtime/ops_cuda.py +129 -37
- tinygrad/runtime/ops_disk.py +111 -43
- tinygrad/runtime/ops_gpu.py +52 -50
- tinygrad/runtime/ops_llvm.py +36 -56
- tinygrad/runtime/ops_metal.py +41 -24
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +625 -0
- tinygrad/runtime/ops_python.py +208 -0
- tinygrad/shape/__init__.py +0 -0
- tinygrad/shape/shapetracker.py +46 -107
- tinygrad/shape/symbolic.py +99 -98
- tinygrad/shape/view.py +162 -45
- tinygrad/tensor.py +2492 -483
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +1 -1
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +31 -13
- tinygrad-0.9.1.dist-info/RECORD +63 -0
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
- tinygrad/features/image.py +0 -93
- tinygrad/features/multi.py +0 -103
- tinygrad/features/search.py +0 -160
- tinygrad/graph.py +0 -106
- tinygrad/jit.py +0 -152
- tinygrad/realize.py +0 -50
- tinygrad/runtime/graph/hip.py +0 -24
- tinygrad/runtime/ops_cpu.py +0 -45
- tinygrad/runtime/ops_hip.py +0 -97
- tinygrad/runtime/ops_torch.py +0 -49
- tinygrad-0.8.0.dist-info/RECORD +0 -41
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
tinygrad/renderer/llvmir.py
CHANGED
@@ -1,35 +1,31 @@
|
|
1
1
|
from typing import Final, Dict, Callable, Any, List, Optional
|
2
2
|
from llvmlite import ir
|
3
|
-
from tinygrad.codegen.linearizer import UOps, UOp
|
4
3
|
from tinygrad.dtype import DType, PtrDType, dtypes
|
5
4
|
from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
|
5
|
+
from tinygrad.codegen.uops import UOps, UOp, UOpGraph
|
6
|
+
from tinygrad.renderer import Renderer
|
6
7
|
|
7
8
|
MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf
|
8
9
|
|
9
10
|
def is_bool_or_unsigned(dtype: DType): return dtype == dtypes.bool or dtypes.is_unsigned(dtype)
|
10
11
|
|
11
12
|
code_for_op: Final[Dict[Op, Callable]] = {
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
BinaryOps.XOR: lambda builder, x, y, var_dtype: builder.xor(x, y),
|
29
|
-
TernaryOps.MULACC: lambda builder, x, y, z, var_dtype: builder.fadd(builder.fmul(x, y, flags=MFLAGS), z, flags=MFLAGS) \
|
30
|
-
if dtypes.is_float(var_dtype) else builder.add(builder.mul(x, y), z),
|
31
|
-
TernaryOps.WHERE: lambda builder, x, y, z, var_dtype: builder.select(x, y, z),
|
32
|
-
}
|
13
|
+
UnaryOps.NEG: lambda builder, x, dtype: builder.neg(x) if dtypes.is_int(dtype) else \
|
14
|
+
(builder.not_(x) if dtype == dtypes.bool else builder.fneg(x, flags=MFLAGS)),
|
15
|
+
UnaryOps.EXP2: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.exp2', [x.type]), [x], fastmath=MFLAGS),
|
16
|
+
UnaryOps.LOG2: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.log2', [x.type]), [x], fastmath=MFLAGS),
|
17
|
+
UnaryOps.RECIP: lambda builder, x, dtype: builder.fdiv(const(1, dtype), x, flags=MFLAGS),
|
18
|
+
UnaryOps.SIN: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.sin', [x.type]), [x], fastmath=MFLAGS),
|
19
|
+
UnaryOps.SQRT: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.sqrt', [x.type]), [x], fastmath=MFLAGS),
|
20
|
+
BinaryOps.ADD: lambda builder, x, y, dtype: builder.or_(x, y) if dtype == dtypes.bool else builder.add(x, y) if dtypes.is_int(dtype) else builder.fadd(x, y, flags=MFLAGS), # noqa: E501
|
21
|
+
BinaryOps.MUL: lambda builder, x, y, dtype: builder.mul(x, y) if is_bool_or_unsigned(dtype) or dtypes.is_int(dtype) else builder.fmul(x, y, flags=MFLAGS), # noqa: E501
|
22
|
+
BinaryOps.IDIV: lambda builder, x, y, dtype: builder.udiv(x, y) if is_bool_or_unsigned(dtype) else builder.sdiv(x, y),
|
23
|
+
BinaryOps.CMPLT: lambda builder, x, y, dtype: builder.icmp_unsigned("<", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("<", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("<", x, y, flags=MFLAGS), # noqa: E501
|
24
|
+
BinaryOps.CMPNE: lambda builder, x, y, dtype: builder.icmp_unsigned("!=", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("!=", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("!=", x, y, flags=MFLAGS), # noqa: E501
|
25
|
+
BinaryOps.MAX: lambda builder, x, y, dtype: builder.select(builder.icmp_unsigned(">", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed(">", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered(">", x, y, flags=MFLAGS), x, y), # noqa: E501
|
26
|
+
BinaryOps.MOD: lambda builder, x, y, dtype: builder.urem(x, y) if is_bool_or_unsigned(dtype) else builder.srem(x, y) if dtypes.is_int(dtype) else builder.frem(x, y), # noqa: E501
|
27
|
+
BinaryOps.XOR: lambda builder, x, y, dtype: builder.xor(x, y),
|
28
|
+
TernaryOps.WHERE: lambda builder, x, y, z, dtype: builder.select(x, y, z)}
|
33
29
|
|
34
30
|
dtype_to_llvm_dtype = { dtypes.bool:ir.IntType(1), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.int16:ir.IntType(16),
|
35
31
|
dtypes.uint16:ir.IntType(16), dtypes.int32:ir.IntType(32), dtypes.uint32:ir.IntType(32), dtypes.int64:ir.IntType(64), dtypes.uint64:ir.IntType(64),
|
@@ -37,7 +33,8 @@ dtype_to_llvm_dtype = { dtypes.bool:ir.IntType(1), dtypes.int8:ir.IntType(8), dt
|
|
37
33
|
|
38
34
|
def cast(bb, val, input_type, output_type, bitcast=False):
|
39
35
|
if input_type == output_type: return val
|
40
|
-
|
36
|
+
llvm_type = dtype_to_llvm_dtype[output_type]
|
37
|
+
if bitcast: return bb[-1].bitcast(val, llvm_type)
|
41
38
|
|
42
39
|
if input_type == dtypes.bfloat16:
|
43
40
|
val = bb[-1].bitcast(bb[-1].shl(bb[-1].sext(val, ir.IntType(32)), ir.Constant(ir.IntType(32), 16)),val, ir.FloatType())
|
@@ -48,118 +45,116 @@ def cast(bb, val, input_type, output_type, bitcast=False):
|
|
48
45
|
|
49
46
|
if dtypes.is_float(input_type):
|
50
47
|
if dtypes.is_float(output_type):
|
51
|
-
if output_type.itemsize > input_type.itemsize
|
52
|
-
|
53
|
-
if dtypes.is_int(output_type):
|
54
|
-
if dtypes.is_unsigned(output_type): return bb[-1].fptoui(val, dtype_to_llvm_dtype[output_type])
|
55
|
-
return bb[-1].fptosi(val, dtype_to_llvm_dtype[output_type])
|
48
|
+
return bb[-1].fpext(val, llvm_type) if output_type.itemsize > input_type.itemsize else bb[-1].fptrunc(val, llvm_type)
|
49
|
+
if dtypes.is_int(output_type): return bb[-1].fptoui(val, llvm_type) if dtypes.is_unsigned(output_type) else bb[-1].fptosi(val, llvm_type)
|
56
50
|
if output_type == dtypes.bool: return bb[-1].fcmp_unordered('!=', cast(bb, val, input_type, dtypes.float32), ir.Constant(ir.FloatType(), 0))
|
57
51
|
|
58
52
|
if dtypes.is_unsigned(input_type) or input_type == dtypes.bool:
|
59
53
|
if output_type == dtypes.float16: return bb[-1].fptrunc(bb[-1].uitofp(val, ir.FloatType()), ir.HalfType())
|
60
54
|
if dtypes.is_float(output_type): return bb[-1].uitofp(val, dtype_to_llvm_dtype[output_type])
|
61
|
-
if dtypes.is_int(output_type):
|
62
|
-
if input_type.itemsize > output_type.itemsize: return bb[-1].trunc(val, dtype_to_llvm_dtype[output_type])
|
63
|
-
return bb[-1].zext(val, dtype_to_llvm_dtype[output_type])
|
55
|
+
if dtypes.is_int(output_type): return bb[-1].trunc(val, llvm_type) if input_type.itemsize > output_type.itemsize else bb[-1].zext(val, llvm_type)
|
64
56
|
if output_type == dtypes.bool: return bb[-1].icmp_unsigned('!=', val, ir.Constant(val.type, 0))
|
65
57
|
|
66
58
|
if dtypes.is_int(input_type):
|
67
59
|
if output_type == dtypes.float16: return bb[-1].fptrunc(bb[-1].sitofp(val, ir.FloatType()), ir.HalfType())
|
68
|
-
if dtypes.is_float(output_type): return bb[-1].sitofp(val,
|
69
|
-
if dtypes.is_int(output_type):
|
70
|
-
if input_type.itemsize > output_type.itemsize: return bb[-1].trunc(val, dtype_to_llvm_dtype[output_type])
|
71
|
-
return bb[-1].sext(val, dtype_to_llvm_dtype[output_type])
|
60
|
+
if dtypes.is_float(output_type): return bb[-1].sitofp(val, llvm_type)
|
61
|
+
if dtypes.is_int(output_type): return bb[-1].trunc(val, llvm_type) if input_type.itemsize > output_type.itemsize else bb[-1].sext(val, llvm_type)
|
72
62
|
if output_type == dtypes.bool: return bb[-1].icmp_signed('!=', val, ir.Constant(val.type, 0))
|
73
63
|
|
74
64
|
raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")
|
75
65
|
|
76
|
-
def const(args, dtype):
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
bb.
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
|
129
|
-
bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].vin[1]]), block._block, bb[-1]._block)
|
130
|
-
if uop == UOps.DEFINE_GLOBAL: lvars[u] = func.args[buf_index[args]]
|
131
|
-
if uop == UOps.DEFINE_ACC:
|
132
|
-
lvars[u] = const(args, dtype)
|
133
|
-
reduce_phis.append(u)
|
134
|
-
if uop == UOps.SPECIAL: lvars[u] = lvars[args.expr]
|
135
|
-
if uop == UOps.CONST: lvars[u] = const(args, dtype)
|
136
|
-
if uop == UOps.LOAD:
|
137
|
-
assert dtype is not None
|
138
|
-
if len(vin) > 2:
|
139
|
-
gate = bb[-1].trunc(lvars[vin[2]], ir.IntType(1))
|
140
|
-
aug_idx = bb[-1].select(gate, lvars[vin[1]], ir.Constant(ir.IntType(32), 0))
|
141
|
-
val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [aug_idx], inbounds=True))
|
142
|
-
val = cast(bb, val, vin[0].dtype, dtype)
|
143
|
-
val = bb[-1].select(gate, val, lvars[vin[3]])
|
66
|
+
def const(args, dtype): return ir.Constant(dtype_to_llvm_dtype[dtype], args)
|
67
|
+
|
68
|
+
class LLVMRenderer(Renderer):
|
69
|
+
device = "LLVM"
|
70
|
+
supports_float4 = False
|
71
|
+
has_local = False
|
72
|
+
has_shared = False
|
73
|
+
global_max = None
|
74
|
+
|
75
|
+
def render(self, name:str, uops:UOpGraph) -> str:
|
76
|
+
# all llvm stuff goes into a module
|
77
|
+
module = ir.Module(name=__file__)
|
78
|
+
|
79
|
+
# extract global buffers (NOTE: this isn't right if DEFINE_GLOBAL is out of order)
|
80
|
+
buf_to_dtype = {u.arg:u.dtype for u in uops if u.op in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}}
|
81
|
+
buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
|
82
|
+
|
83
|
+
# create llvm function
|
84
|
+
func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values() if dtype is not None]
|
85
|
+
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if isinstance(dt, PtrDType) else x for x,dt in func_dtypes]), name=name)
|
86
|
+
for a in func.args:
|
87
|
+
if a.type.is_pointer: a.add_attribute("noalias")
|
88
|
+
|
89
|
+
# add the function attribute "no-nans-fp-math"="true", which informs llvm that it allowed to use vectorization optimizations
|
90
|
+
func.attributes._known = func.attributes._known.union(frozenset(['"no-nans-fp-math"="true"']))
|
91
|
+
func.attributes.add('"no-nans-fp-math"="true"')
|
92
|
+
|
93
|
+
bb = [ir.IRBuilder(func.append_basic_block("entry"))]
|
94
|
+
loop_blocks: List = []
|
95
|
+
reduce_phis: List = []
|
96
|
+
# TODO: newvar probably shouldn't be optional
|
97
|
+
lvars: Dict[Optional[UOp], Any] = {} # this Any is an llvm type
|
98
|
+
|
99
|
+
for bufname,dtype in buf_to_dtype.items():
|
100
|
+
if not isinstance(dtype, PtrDType) and dtype == dtypes.int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
|
101
|
+
|
102
|
+
for u in uops:
|
103
|
+
uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
|
104
|
+
if uop is UOps.STORE:
|
105
|
+
element = cast(bb, lvars[src[2]], src[2].dtype, src[0].dtype)
|
106
|
+
if len(src) > 3:
|
107
|
+
with bb[-1].if_then(lvars[src[3]]):
|
108
|
+
bb[-1].store(element, bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
|
109
|
+
else:
|
110
|
+
bb[-1].store(element, bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
|
111
|
+
elif uop is UOps.ENDRANGE:
|
112
|
+
loop_entry_bb, phis = loop_blocks.pop()
|
113
|
+
idx_p1 = bb[-1].add(lvars[src[0]], ir.Constant(ir.IntType(32), 1))
|
114
|
+
lvars[src[0]].add_incoming(idx_p1, bb[-1].block)
|
115
|
+
for n,phi in phis: phi.add_incoming(lvars[n], bb[-1].block)
|
116
|
+
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
|
117
|
+
bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[src[0].src[1]]), loop_entry_bb, bb[-1].block)
|
144
118
|
else:
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
119
|
+
assert dtype is not None, f"None dtype for uop {uop}"
|
120
|
+
if uop is UOps.RANGE:
|
121
|
+
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}")))
|
122
|
+
bb[-2].branch(bb[-1].block)
|
123
|
+
|
124
|
+
phis = []
|
125
|
+
for rp in reduce_phis:
|
126
|
+
incoming = lvars[rp]
|
127
|
+
lvars[rp] = bb[-1].phi(dtype_to_llvm_dtype[rp.dtype])
|
128
|
+
lvars[rp].add_incoming(incoming, bb[-2].block)
|
129
|
+
phis.append((rp, lvars[rp]))
|
130
|
+
|
131
|
+
lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}")
|
132
|
+
lvars[u].add_incoming(lvars[src[0]], bb[-2].block)
|
133
|
+
loop_blocks.append((bb[-1].block, phis))
|
134
|
+
elif uop is UOps.DEFINE_ACC:
|
135
|
+
lvars[u] = const(src[0].arg, dtype)
|
136
|
+
reduce_phis.append(u)
|
137
|
+
elif uop is UOps.LOAD:
|
138
|
+
if len(src) > 2:
|
139
|
+
aug_idx = bb[-1].select(lvars[src[2]], lvars[src[1]], ir.Constant(ir.IntType(32), 0))
|
140
|
+
val = bb[-1].load(bb[-1].gep(lvars[src[0]], [aug_idx], inbounds=True))
|
141
|
+
val = bb[-1].select(lvars[src[2]], val, lvars[src[3]])
|
142
|
+
else:
|
143
|
+
val = bb[-1].load(bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
|
144
|
+
lvars[u] = val
|
145
|
+
elif uop is UOps.PHI:
|
146
|
+
lvars[u] = lvars[src[1]]
|
147
|
+
# PHI UOps can link to other PHI Uops, backtrace this to DEFINE_ACC
|
148
|
+
backward = src[0]
|
149
|
+
while backward.op is UOps.PHI: backward = backward.src[0]
|
150
|
+
lvars[backward] = lvars[u]
|
151
|
+
elif uop is UOps.ALU:
|
152
|
+
lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in src], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPNE) else src[0].dtype)
|
153
|
+
elif uop in {UOps.CAST, UOps.BITCAST}: lvars[u] = cast(bb, lvars[src[0]], src[0].dtype, dtype, bitcast=uop is UOps.BITCAST)
|
154
|
+
elif uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
|
155
|
+
elif uop is UOps.SPECIAL: lvars[u] = lvars[args.expr]
|
156
|
+
elif uop is UOps.CONST: lvars[u] = const(args, dtype)
|
157
|
+
else: raise RuntimeError(f"failed to render {uop}")
|
158
|
+
|
159
|
+
bb[-1].ret_void()
|
160
|
+
return str(module)
|
File without changes
|