tinygrad-0.9.0-py3-none-any.whl → tinygrad-0.9.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/__init__.py +0 -0
- tinygrad/codegen/kernel.py +78 -90
- tinygrad/codegen/linearizer.py +237 -169
- tinygrad/codegen/uops.py +278 -242
- tinygrad/device.py +147 -10
- tinygrad/dtype.py +7 -7
- tinygrad/engine/graph.py +16 -16
- tinygrad/engine/jit.py +39 -36
- tinygrad/engine/realize.py +6 -5
- tinygrad/engine/schedule.py +15 -7
- tinygrad/engine/search.py +6 -3
- tinygrad/function.py +17 -23
- tinygrad/helpers.py +77 -8
- tinygrad/lazy.py +26 -26
- tinygrad/multi.py +13 -9
- tinygrad/nn/__init__.py +1 -1
- tinygrad/nn/datasets.py +2 -1
- tinygrad/nn/state.py +3 -4
- tinygrad/ops.py +49 -16
- tinygrad/renderer/__init__.py +8 -4
- tinygrad/renderer/assembly.py +93 -100
- tinygrad/renderer/cstyle.py +47 -42
- tinygrad/renderer/llvmir.py +30 -30
- tinygrad/runtime/__init__.py +0 -0
- tinygrad/runtime/autogen/amd_gpu.py +11504 -1
- tinygrad/runtime/autogen/comgr.py +36 -10
- tinygrad/runtime/autogen/hsa.py +146 -14
- tinygrad/runtime/autogen/io_uring.py +1486 -0
- tinygrad/runtime/autogen/nv_gpu.py +269 -0
- tinygrad/runtime/driver/__init__.py +0 -0
- tinygrad/runtime/driver/hip_comgr.py +20 -11
- tinygrad/runtime/graph/__init__.py +0 -0
- tinygrad/runtime/graph/clang.py +3 -2
- tinygrad/runtime/graph/cuda.py +2 -2
- tinygrad/runtime/graph/hcq.py +122 -78
- tinygrad/runtime/ops_amd.py +302 -316
- tinygrad/runtime/ops_cuda.py +3 -3
- tinygrad/runtime/ops_disk.py +70 -5
- tinygrad/runtime/ops_gpu.py +2 -2
- tinygrad/runtime/ops_metal.py +5 -6
- tinygrad/runtime/ops_npy.py +1 -1
- tinygrad/runtime/ops_nv.py +161 -166
- tinygrad/runtime/ops_python.py +20 -16
- tinygrad/shape/__init__.py +0 -0
- tinygrad/shape/shapetracker.py +5 -2
- tinygrad/shape/symbolic.py +1 -3
- tinygrad/shape/view.py +34 -19
- tinygrad/tensor.py +219 -135
- {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +14 -6
- tinygrad-0.9.1.dist-info/RECORD +63 -0
- {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
- tinygrad/runtime/driver/hsa.py +0 -143
- tinygrad/runtime/graph/hsa.py +0 -171
- tinygrad/runtime/ops_hsa.py +0 -278
- tinygrad-0.9.0.dist-info/RECORD +0 -60
- {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
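The deleted files (tinygrad/runtime/ops_hsa.py, runtime/graph/hsa.py, runtime/driver/hsa.py) mean code pinned to the HSA backend under 0.9.0 has to resolve to a different device in 0.9.1. A minimal sketch for checking which backend a given install selects, assuming only the public `Device` and `Tensor` entry points:

```python
from tinygrad import Device, Tensor

# print the default backend this install resolves to (e.g. "AMD", "NV", "CUDA", "CLANG")
print(Device.DEFAULT)

# tiny smoke test on whatever that backend is
print((Tensor([1.0, 2.0, 3.0]) + 1).numpy())
```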
tinygrad/renderer/llvmir.py
CHANGED
@@ -1,9 +1,8 @@
 from typing import Final, Dict, Callable, Any, List, Optional
 from llvmlite import ir
-from tinygrad.codegen.linearizer import UOps, UOp
 from tinygrad.dtype import DType, PtrDType, dtypes
 from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
-from tinygrad.codegen.uops import UOpGraph
+from tinygrad.codegen.uops import UOps, UOp, UOpGraph
 from tinygrad.renderer import Renderer
 
 MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf
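UOps and UOp now live next to UOpGraph in tinygrad.codegen.uops instead of tinygrad.codegen.linearizer. Downstream code that used the old path needs the new import; a small compatibility shim, assuming nothing beyond the two import locations shown in this hunk:

```python
try:
    # 0.9.1+: UOps and UOp moved next to UOpGraph
    from tinygrad.codegen.uops import UOps, UOp, UOpGraph
except ImportError:
    # 0.9.0 fallback
    from tinygrad.codegen.linearizer import UOps, UOp
    from tinygrad.codegen.uops import UOpGraph
```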
@@ -15,14 +14,14 @@ code_for_op: Final[Dict[Op, Callable]] = {
     (builder.not_(x) if dtype == dtypes.bool else builder.fneg(x, flags=MFLAGS)),
   UnaryOps.EXP2: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.exp2', [x.type]), [x], fastmath=MFLAGS),
   UnaryOps.LOG2: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.log2', [x.type]), [x], fastmath=MFLAGS),
+  UnaryOps.RECIP: lambda builder, x, dtype: builder.fdiv(const(1, dtype), x, flags=MFLAGS),
   UnaryOps.SIN: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.sin', [x.type]), [x], fastmath=MFLAGS),
   UnaryOps.SQRT: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.sqrt', [x.type]), [x], fastmath=MFLAGS),
   BinaryOps.ADD: lambda builder, x, y, dtype: builder.or_(x, y) if dtype == dtypes.bool else builder.add(x, y) if dtypes.is_int(dtype) else builder.fadd(x, y, flags=MFLAGS), # noqa: E501
-  BinaryOps.SUB: lambda builder, x, y, dtype: builder.sub(x, y) if dtypes.is_int(dtype) else builder.fsub(x, y, flags=MFLAGS),
   BinaryOps.MUL: lambda builder, x, y, dtype: builder.mul(x, y) if is_bool_or_unsigned(dtype) or dtypes.is_int(dtype) else builder.fmul(x, y, flags=MFLAGS), # noqa: E501
-  BinaryOps.
+  BinaryOps.IDIV: lambda builder, x, y, dtype: builder.udiv(x, y) if is_bool_or_unsigned(dtype) else builder.sdiv(x, y),
   BinaryOps.CMPLT: lambda builder, x, y, dtype: builder.icmp_unsigned("<", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("<", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("<", x, y, flags=MFLAGS), # noqa: E501
-  BinaryOps.
+  BinaryOps.CMPNE: lambda builder, x, y, dtype: builder.icmp_unsigned("!=", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("!=", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("!=", x, y, flags=MFLAGS), # noqa: E501
   BinaryOps.MAX: lambda builder, x, y, dtype: builder.select(builder.icmp_unsigned(">", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed(">", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered(">", x, y, flags=MFLAGS), x, y), # noqa: E501
   BinaryOps.MOD: lambda builder, x, y, dtype: builder.urem(x, y) if is_bool_or_unsigned(dtype) else builder.srem(x, y) if dtypes.is_int(dtype) else builder.frem(x, y), # noqa: E501
   BinaryOps.XOR: lambda builder, x, y, dtype: builder.xor(x, y),
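The new UnaryOps.RECIP entry lowers a reciprocal to an fdiv of 1/x carrying the renderer's fast-math flags. A self-contained llvmlite sketch of the same lowering outside of tinygrad (the standalone module and the `recip` function name are illustrative only):

```python
from llvmlite import ir

# same fast-math flags the renderer uses (everything except nnan and ninf)
MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc')

def build_recip_module():
    # emits: float recip(float x) { return 1.0f / x; } with fast-math flags on the fdiv
    module = ir.Module(name="recip_example")
    fnty = ir.FunctionType(ir.FloatType(), [ir.FloatType()])
    func = ir.Function(module, fnty, name="recip")
    builder = ir.IRBuilder(func.append_basic_block("entry"))
    one = ir.Constant(ir.FloatType(), 1.0)
    builder.ret(builder.fdiv(one, func.args[0], flags=MFLAGS))
    return module

if __name__ == "__main__":
    print(build_recip_module())  # dump the generated LLVM IR
```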
@@ -68,16 +67,17 @@ def const(args, dtype): return ir.Constant(dtype_to_llvm_dtype[dtype], args)
 
 class LLVMRenderer(Renderer):
   device = "LLVM"
-  supports_float4=False
-  has_local=False
-  has_shared=False
+  supports_float4 = False
+  has_local = False
+  has_shared = False
+  global_max = None
 
   def render(self, name:str, uops:UOpGraph) -> str:
     # all llvm stuff goes into a module
     module = ir.Module(name=__file__)
 
     # extract global buffers (NOTE: this isn't right if DEFINE_GLOBAL is out of order)
-    buf_to_dtype = {u.arg:u.dtype for u in uops if u.
+    buf_to_dtype = {u.arg:u.dtype for u in uops if u.op in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}}
     buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
 
     # create llvm function
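The rewritten buf_to_dtype comprehension keys buffers by the arg of each DEFINE_GLOBAL/DEFINE_VAR uop, and buf_index then numbers them in encounter order to form the positional arguments of the generated LLVM function. A toy illustration of that pattern with a made-up stand-in type (FakeUOp is not part of tinygrad; it only mimics the three attributes the comprehension reads):

```python
from dataclasses import dataclass
from typing import Any

@dataclass(frozen=True)
class FakeUOp:
    op: str
    dtype: Any
    arg: Any

uops = [
    FakeUOp("DEFINE_GLOBAL", "float32*", 0),   # output buffer
    FakeUOp("DEFINE_GLOBAL", "float32*", 1),   # input buffer
    FakeUOp("DEFINE_VAR", "int32", "n"),       # bound symbolic variable
    FakeUOp("ALU", "float32", None),           # ignored: not a buffer definition
]

# buffers keyed by arg, then numbered in encounter order -> function-argument index
buf_to_dtype = {u.arg: u.dtype for u in uops if u.op in {"DEFINE_GLOBAL", "DEFINE_VAR"}}
buf_index = {x: i for i, x in enumerate(buf_to_dtype)}
print(buf_index)  # {0: 0, 1: 1, 'n': 2}
```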
@@ -100,21 +100,21 @@ class LLVMRenderer(Renderer):
       if not isinstance(dtype, PtrDType) and dtype == dtypes.int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
 
     for u in uops:
-      uop,dtype,
+      uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
       if uop is UOps.STORE:
-        element = cast(bb, lvars[
-        if len(
-          with bb[-1].if_then(lvars[
-            bb[-1].store(element, bb[-1].gep(lvars[
+        element = cast(bb, lvars[src[2]], src[2].dtype, src[0].dtype)
+        if len(src) > 3:
+          with bb[-1].if_then(lvars[src[3]]):
+            bb[-1].store(element, bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
         else:
-          bb[-1].store(element, bb[-1].gep(lvars[
+          bb[-1].store(element, bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
       elif uop is UOps.ENDRANGE:
         loop_entry_bb, phis = loop_blocks.pop()
-        idx_p1 = bb[-1].add(lvars[
-        lvars[
+        idx_p1 = bb[-1].add(lvars[src[0]], ir.Constant(ir.IntType(32), 1))
+        lvars[src[0]].add_incoming(idx_p1, bb[-1].block)
         for n,phi in phis: phi.add_incoming(lvars[n], bb[-1].block)
         bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
-        bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[
+        bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[src[0].src[1]]), loop_entry_bb, bb[-1].block)
       else:
         assert dtype is not None, f"None dtype for uop {uop}"
         if uop is UOps.RANGE:
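In the 0.9.1 STORE branch, a fourth source (src[3]) acts as a gate: the store is wrapped in if_then so it only executes when the gate is true. A standalone llvmlite sketch of that gated-store shape (the `gated_store` function and its signature are made up for illustration, not tinygrad API):

```python
from llvmlite import ir

def build_gated_store():
    # emits: void gated_store(float* buf, i32 idx, float val, i1 gate) { if (gate) buf[idx] = val; }
    module = ir.Module(name="gated_store_example")
    fnty = ir.FunctionType(ir.VoidType(), [ir.PointerType(ir.FloatType()), ir.IntType(32),
                                           ir.FloatType(), ir.IntType(1)])
    func = ir.Function(module, fnty, name="gated_store")
    buf, idx, val, gate = func.args
    builder = ir.IRBuilder(func.append_basic_block("entry"))
    with builder.if_then(gate):
        # store only happens inside the conditional block
        builder.store(val, builder.gep(buf, [idx], inbounds=True))
    builder.ret_void()
    return module

if __name__ == "__main__":
    print(build_gated_store())
```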
@@ -129,28 +129,28 @@
             phis.append((rp, lvars[rp]))
 
           lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}")
-          lvars[u].add_incoming(lvars[
+          lvars[u].add_incoming(lvars[src[0]], bb[-2].block)
           loop_blocks.append((bb[-1].block, phis))
         elif uop is UOps.DEFINE_ACC:
-          lvars[u] = const(
+          lvars[u] = const(src[0].arg, dtype)
           reduce_phis.append(u)
         elif uop is UOps.LOAD:
-          if len(
-            aug_idx = bb[-1].select(lvars[
-            val = bb[-1].load(bb[-1].gep(lvars[
-            val = bb[-1].select(lvars[
+          if len(src) > 2:
+            aug_idx = bb[-1].select(lvars[src[2]], lvars[src[1]], ir.Constant(ir.IntType(32), 0))
+            val = bb[-1].load(bb[-1].gep(lvars[src[0]], [aug_idx], inbounds=True))
+            val = bb[-1].select(lvars[src[2]], val, lvars[src[3]])
          else:
-            val = bb[-1].load(bb[-1].gep(lvars[
+            val = bb[-1].load(bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
           lvars[u] = val
         elif uop is UOps.PHI:
-          lvars[u] = lvars[
+          lvars[u] = lvars[src[1]]
           # PHI UOps can link to other PHI Uops, backtrace this to DEFINE_ACC
-          backward = 
-          while backward.
+          backward = src[0]
+          while backward.op is UOps.PHI: backward = backward.src[0]
           lvars[backward] = lvars[u]
         elif uop is UOps.ALU:
-          lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in 
-        elif uop in {UOps.CAST, UOps.BITCAST}: lvars[u] = cast(bb, lvars[
+          lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in src], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPNE) else src[0].dtype)
+        elif uop in {UOps.CAST, UOps.BITCAST}: lvars[u] = cast(bb, lvars[src[0]], src[0].dtype, dtype, bitcast=uop is UOps.BITCAST)
         elif uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
         elif uop is UOps.SPECIAL: lvars[u] = lvars[args.expr]
         elif uop is UOps.CONST: lvars[u] = const(args, dtype)
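The LOAD branch with more than two sources is effectively a masked load: the index is clamped to 0 when the gate (src[2]) is false, the value is loaded unconditionally from the clamped address, and a select picks between the loaded value and the fallback (src[3]). A standalone llvmlite sketch of the same pattern (the `masked_load` function and its signature are illustrative, not tinygrad API):

```python
from llvmlite import ir

def build_masked_load():
    # emits: float masked_load(float* buf, i32 idx, i1 gate, float alt)
    #        { return gate ? buf[gate ? idx : 0] : alt; }
    module = ir.Module(name="masked_load_example")
    fnty = ir.FunctionType(ir.FloatType(), [ir.PointerType(ir.FloatType()), ir.IntType(32),
                                            ir.IntType(1), ir.FloatType()])
    func = ir.Function(module, fnty, name="masked_load")
    buf, idx, gate, alt = func.args
    builder = ir.IRBuilder(func.append_basic_block("entry"))
    # clamp the index to 0 when the gate is off, so the unconditional load stays in bounds
    aug_idx = builder.select(gate, idx, ir.Constant(ir.IntType(32), 0))
    val = builder.load(builder.gep(buf, [aug_idx], inbounds=True))
    # pick the loaded value when gated on, the fallback otherwise
    builder.ret(builder.select(gate, val, alt))
    return module

if __name__ == "__main__":
    print(build_masked_load())
```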