tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +11 -6
- tinygrad/codegen/kernel.py +308 -175
- tinygrad/codegen/linearize.py +95 -0
- tinygrad/codegen/lowerer.py +143 -0
- tinygrad/codegen/transcendental.py +257 -0
- tinygrad/codegen/uopgraph.py +506 -0
- tinygrad/device.py +72 -171
- tinygrad/dtype.py +122 -47
- tinygrad/engine/jit.py +184 -87
- tinygrad/{lazy.py → engine/lazy.py} +74 -66
- tinygrad/engine/memory.py +51 -0
- tinygrad/engine/realize.py +86 -61
- tinygrad/engine/schedule.py +366 -317
- tinygrad/engine/search.py +58 -47
- tinygrad/function.py +59 -58
- tinygrad/helpers.py +120 -102
- tinygrad/multi.py +82 -78
- tinygrad/nn/__init__.py +116 -67
- tinygrad/nn/datasets.py +12 -5
- tinygrad/nn/optim.py +1 -1
- tinygrad/nn/state.py +91 -6
- tinygrad/ops.py +1126 -143
- tinygrad/renderer/__init__.py +47 -23
- tinygrad/renderer/cstyle.py +338 -265
- tinygrad/renderer/llvmir.py +125 -143
- tinygrad/renderer/ptx.py +225 -0
- tinygrad/runtime/autogen/adreno.py +17904 -0
- tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
- tinygrad/runtime/autogen/cuda.py +6 -162
- tinygrad/runtime/autogen/io_uring.py +97 -63
- tinygrad/runtime/autogen/kfd.py +60 -47
- tinygrad/runtime/autogen/kgsl.py +1386 -0
- tinygrad/runtime/autogen/libc.py +5462 -0
- tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
- tinygrad/runtime/autogen/nvrtc.py +579 -0
- tinygrad/runtime/autogen/opencl.py +11 -11
- tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
- tinygrad/runtime/graph/clang.py +3 -3
- tinygrad/runtime/graph/cuda.py +11 -15
- tinygrad/runtime/graph/hcq.py +120 -107
- tinygrad/runtime/graph/metal.py +71 -43
- tinygrad/runtime/ops_amd.py +244 -323
- tinygrad/runtime/ops_clang.py +12 -5
- tinygrad/runtime/ops_cloud.py +220 -0
- tinygrad/runtime/ops_cuda.py +42 -99
- tinygrad/runtime/ops_disk.py +25 -26
- tinygrad/runtime/ops_dsp.py +181 -0
- tinygrad/runtime/ops_gpu.py +29 -16
- tinygrad/runtime/ops_hip.py +68 -0
- tinygrad/runtime/ops_llvm.py +15 -10
- tinygrad/runtime/ops_metal.py +147 -64
- tinygrad/runtime/ops_nv.py +356 -397
- tinygrad/runtime/ops_python.py +78 -79
- tinygrad/runtime/ops_qcom.py +405 -0
- tinygrad/runtime/support/__init__.py +0 -0
- tinygrad/runtime/support/compiler_cuda.py +77 -0
- tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
- tinygrad/runtime/support/elf.py +38 -0
- tinygrad/runtime/support/hcq.py +539 -0
- tinygrad/shape/shapetracker.py +40 -50
- tinygrad/shape/view.py +102 -63
- tinygrad/tensor.py +1109 -365
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
- tinygrad-0.10.0.dist-info/RECORD +77 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/linearizer.py +0 -528
- tinygrad/codegen/uops.py +0 -451
- tinygrad/engine/graph.py +0 -100
- tinygrad/renderer/assembly.py +0 -269
- tinygrad/shape/symbolic.py +0 -327
- tinygrad-0.9.1.dist-info/RECORD +0 -63
- /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
tinygrad/renderer/assembly.py
DELETED
@@ -1,269 +0,0 @@
|
|
1
|
-
from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
|
2
|
-
import struct, math
|
3
|
-
from collections import defaultdict
|
4
|
-
from tinygrad.helpers import DEBUG
|
5
|
-
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op
|
6
|
-
from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
|
7
|
-
from tinygrad.codegen.uops import UOps, UOp, UOpGraph, PatternMatcher, UPat
|
8
|
-
from tinygrad.renderer import Renderer, TensorCore
|
9
|
-
|
10
|
-
def render_val(x, dtype):
  """Render the scalar `x` as a PTX immediate literal for `dtype`.

  Float dtypes become hex bit patterns, most significant byte first: "0d" + 8 bytes for double,
  "0x" + 2 bytes for half, "0f" + 4 bytes (single precision) for every other float dtype.
  Non-float dtypes become a decimal integer, suffixed with "U" when the dtype is unsigned.
  """
  if not dtypes.is_float(dtype):
    suffix = "U" if dtypes.is_unsigned(dtype) else ""
    return f"{int(x)}{suffix}"
  if dtype == dtypes.double: prefix, fmt = "0d", "d"
  elif dtype == dtypes.half: prefix, fmt = "0x", "e"
  else: prefix, fmt = "0f", "f"
  # struct packs in native byte order; reversing yields the big-endian digit order used by the literal
  return prefix + "".join(f"{byte:02X}" for byte in struct.pack(fmt, x)[::-1])
|
16
|
-
|
17
|
-
class PTXRenderer(Renderer):
  """Renders a linearized UOpGraph into PTX assembly text for the "CUDA" device.

  Values are held in virtual registers named `%<prefix>_<ptxtype>_<n>` (see `ssa` in `render`); `render_kernel`
  declares one `.reg` array per prefix from the counters collected while rendering.
  """
  device = "CUDA"
  suffix = "PTX"
  # launch / shared-memory limits consumed by the Renderer base (units presumably threads and bytes — confirm)
  global_max = (2147483647, 65535, 65535)
  local_max = (1024, 1024, 64)
  shared_max = 49152
  # single half-in / float-out tensor core config; its WMMA uops are rendered as mma.sync.aligned.m16n8k16 in `render`
  tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[0],[5],[-2],[0],[-1,1,2,-3],[3,4]], [[3],[4],[0],[0],[5],[-1,1,2,-2],[0]], [[-1],[1],[5],[-2],[2],[0],[3,4]] ], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float)])] # noqa: E501
  # tensor cores are enabled only when int(arch[3:]) >= 80; presumably arch looks like "sm_80" — confirm with callers
  def __init__(self, arch:str): self.tensor_cores = PTXRenderer.tensor_cores if int(arch[3:]) >= 80 else []

  # language options
  # VERSION and TARGET are placeholders; nothing in this class substitutes them (presumably done by the caller)
  kernel_prefix = """.version VERSION
.target TARGET
.address_size 64
.visible .entry"""
  barrier = "bar.sync\t0;"
  # PTX special registers indexed by axis 0..2 (chr(120+i) -> 'x', 'y', 'z'): global/group ids, grid dims, local ids
  gid = [f'%ctaid.{chr(120+i)}' for i in range(3)]
  gdim = [f'%nctaid.{chr(120+i)}' for i in range(3)]
  lid = [f'%tid.{chr(120+i)}' for i in range(3)]
  # op -> instruction builder. Arguments: d = destination register, a/b/c = operand registers or immediates,
  # dt = tinygrad DType, name = PTX type suffix from `types` (e.g. "s32", "f16", "pred").
  asm_for_op: Dict[Op, Callable] = {
    # predicates negate with `not`; unsigned ints negate as 0 - a
    UnaryOps.NEG: lambda d,a,dt,name: f"not.pred {d}, {a};" if name == "pred" else f"sub.{name} {d}, 0, {a};" if dtypes.is_unsigned(dt) \
      else f"neg.{name} {d}, {a};",
    UnaryOps.RECIP: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};",
    UnaryOps.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", UnaryOps.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
    UnaryOps.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", UnaryOps.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
    # shl uses the untyped bit-size form (.bNN) of the type
    BinaryOps.SHR: lambda d,a,b,dt,name: f"shr.{name} {d}, {a}, {b};", BinaryOps.SHL: lambda d,a,b,dt,name: f"shl.b{name[1:]} {d}, {a}, {b};",
    # on predicates ADD is rendered as logical or
    BinaryOps.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};",
    # on bool MUL is rendered as logical and; integer multiply keeps the low half (.lo)
    BinaryOps.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
    BinaryOps.XOR: lambda d,a,b,dt,name: f"xor.pred {d}, {a}, {b};" if name == "pred" else f"xor.b{name[1:]} {d}, {a}, {b};",
    BinaryOps.IDIV: lambda d,a,b,dt,name: f"div.{name} {d}, {a}, {b};",
    BinaryOps.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", BinaryOps.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
    # comparisons write a predicate register; `name` here is the operand type (see the ALU branch in render)
    BinaryOps.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};",
    BinaryOps.CMPNE: lambda d,a,b,dt,name: f"setp.ne.{name} {d}, {a}, {b};",
    TernaryOps.MULACC: lambda d,a,b,c,dt,name: f"{'fma.rn' if dtypes.is_float(dt) else 'mad.lo'}.{name} {d}, {a}, {b}, {c};",
    # WHERE(a, b, c): two predicated movs for predicate results, otherwise selp (f16 selects as b16)
    TernaryOps.WHERE: lambda d,a,b,c,dt,name:
      f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
  }

  # ops rendered natively on f16; ptx_matcher upcasts half ALU ops not listed here to float32 and casts back
  supports_half: List[Op] = [UnaryOps.NEG, UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT,
                             TernaryOps.WHERE]
  # register type used for each dtype
  # HACK: Use s16 and u16 for int8 and uint8 buffers. This can be wrong in cast.
  types: Dict[DType, str] = { dtypes.int8: "s16", dtypes.int16: "s16", dtypes.int32: "s32", dtypes.int64: "s64",
                              dtypes.uint8: "u16", dtypes.uint16: "u16", dtypes.uint32: "u32", dtypes.uint64: "u64",
                              dtypes.float16: "f16", dtypes.float32: "f32", dtypes.float64: "f64", dtypes.bool: "pred" }

  # type used in ld/st instructions: 8-bit types keep their real width in memory, bools are stored as u8
  mem_types: Dict[DType, str] = types.copy()
  mem_types.update({dtypes.int8: "s8", dtypes.uint8: "u8", dtypes.bool: "u8", dtypes.float16: "b16"})

  # constants of these dtypes are always materialized into a register by `const` in render
  const_requires_mov: List[DType] = [dtypes.half, dtypes.bool]

  def render_const(self, x:ConstType, dtype:DType, mov=None) -> Union[List[str], str]:
    """Return the immediate literal for `x`, or, when `mov` names a register, the instruction list that loads it.

    Bools always produce a `setp` instruction list (even when `mov` is None).
    """
    val = render_val(x, dtype)
    if dtype == dtypes.bool: return [f"setp.ne.s16 {mov}, {val}, 0;"]
    return [f"mov.b{self.types[dtype][1:]} {mov}, {val};"] if mov else val

  def render_local(self, dest, name, size, dtype) -> List[str]:
    """Declare a 4-byte-aligned .shared byte array of size*itemsize bytes and move its address into `dest`."""
    return [f".shared .align 4 .b8 {name}[{size*dtype.itemsize}];", f"mov.u64 {dest}, {name}[0];"]

  # initialize the loop index and emit the loop's label; `acc` is unused
  def render_loop(self, idx, start, label, acc=None) -> List[str]: return [f"mov.u32 {idx}, {start};", f"{label}:"]

  # branch to b1, predicated on `pred` when given
  def render_bra(self, b1, pred=None) -> List[str]: return [f"@{pred} bra {b1};"] if pred else [f"bra {b1};"]

  def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="", offset=0) -> List[str]:
    """Load from [loc+offset] in state space `ss`; when gated, a false gate moves `alt` into `dest` instead."""
    assert dtype != dtypes.bool
    if gate: return [f"@{gate} ld{ss}.{self.mem_types[dtype]} {dest}, [{loc}+{offset}];", f"@!{gate} mov.b{self.types[dtype][1:]} {dest}, {alt};"]
    return [f"ld{ss}.{self.mem_types[dtype]} {dest}, [{loc}+{offset}];"]

  def render_store(self, loc, val, dtype, gate=None, ss="", offset=0) -> List[str]:
    """Store `val` to [loc+offset] in state space `ss`, predicated on `gate` when given."""
    return [(f"@{gate} " if gate else "") + f"st{ss}.{self.mem_types[dtype]} [{loc}+{offset}], {val};"]

  def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pred=False) -> List[str]:
    """Convert `a` (of `atype`) into register `d` (of `dtype`); `pred` is unused.

    bitcast -> bit-preserving mov; from bool -> selp 1/0; to bool -> setp.ne against 0; otherwise cvt, with .rzi for
    float->int and .rn for float narrowing or int/bool->float.
    """
    if bitcast: return [f"mov.b{self.types[dtype][1:]} {d}, {a};"]
    if atype == dtypes.bool: return[f"selp.b{self.types[dtype][1:]} {d}, {render_val(1, dtype)}, {render_val(0, dtype)}, {a};"]
    if dtype == dtypes.bool: return [f"setp.ne.b{self.types[atype][1:]} {d}, {a}, {self.render_const(0, atype)};"]
    rnd = ('.rzi' if dtypes.is_int(dtype) and dtypes.is_float(atype) else
           '.rn' if dtypes.is_float(dtype) and (dtype.itemsize < atype.itemsize or dtypes.is_int(atype) or atype == dtypes.bool) else '')
    return [f"cvt{rnd}.{self.types[dtype]}.{self.types[atype]} {d}, {a};"]

  def render_kernel(self, kernel, function_name, bufs, regs) -> str:
    """Assemble the final kernel text: prefix, .param list, .reg declarations, body and a trailing `ret;`.

    `regs` yields (prefix, count) pairs; a prefix such as "alu_f32_" declares `.reg .f32 %alu_f32_<count>;`.
    """
    kernel = [f".reg .{reg.split('_')[-2]} %{reg}<{cnt}>;" for reg,cnt in regs] + kernel + ["ret;"]
    # lines starting with "$" are kept verbatim; otherwise indent with a tab and replace the first space with one tab
    # (mnemonic longer than 7 chars) or two tabs, aligning operands
    def fmt(line): return line if line[0]=="$" else "\t" + line.replace(" ", "\t" if len(line.split(" ")[0]) > 7 else "\t\t", 1)
    return (f"{self.kernel_prefix} {function_name}(\n\t" +
            ',\n\t'.join([f".param .{'u64' if dtype.__class__ == PtrDType else self.types[dtype]} {name}" for name,dtype in bufs]) + "\n)\n{\n" +
            '\n'.join([fmt(line) for op in kernel for line in op.splitlines()]) +
            "\n}")

  def render(self, name:str, uops:UOpGraph) -> str:
    """Linearize `uops` with ptx_matcher and emit the PTX text for kernel `name`."""
    kernel:List[str] = []
    bufs = []

    uops.linearize(ptx_matcher)
    if DEBUG >= 4: uops.print()

    # append one (possibly multi-line) instruction group to the kernel body
    def kk(*s: str): kernel.append("\n".join(s))

    c: DefaultDict[str, int] = defaultdict(int)  # per-prefix register counters, later turned into .reg declarations
    r: Dict[UOp, Union[List[str], str]] = {}     # uop -> register name(s) (a list for vector values)
    # allocate a fresh register "%<prefix>_<type>_<n>"; the type comes from `dtype` or else from u's dtype
    def ssa(prefix:str, u:Optional[UOp]=None, dtype:Optional[str]=None) -> str:
      nonlocal c, r
      prefix += f"_{dtype if dtype is not None else self.types[cast(DType, cast(UOp, u).dtype)]}_"
      c[prefix] += 1
      if u is not None: r[u] = f"%{prefix}{c[prefix]-1}"
      return f"%{prefix}{c[prefix]-1}"

    # return a constant as an immediate, or as a register loaded via mov when requested/required by its dtype
    def const(x:ConstType, dtype:DType, mov=False):
      if mov or dtype in self.const_requires_mov:
        kk(*self.render_const(x, dtype, mov=(out:=ssa('const', dtype=self.types[dtype]))))
        return out
      return self.render_const(x, dtype)

    # cast `a` into a new register unless the types already match (or atype is a pointer); `pred` is unused
    def _cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
      if atype == dtype or isinstance(atype, PtrDType):
        if u: r[u] = a
        return a
      kk(*self.render_cast((ret:=ssa('cast', u, self.types[dtype])), a, dtype, atype, bitcast))
      return ret

    for u in uops:
      uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
      if uop is UOps.IF:
        assert src[0].dtype is not None
        kk(*self.render_bra(f"IF_{r[src[0]][1:]}_{cast(List, uops._uops).index(u)}", _cast(r[src[0]], dtypes.bool, src[0].dtype, u=u, pred=True)))
      elif uop is UOps.BARRIER and self.barrier: kk(self.barrier)
      elif uop is UOps.ENDRANGE:
        # increment the loop index, compare it with the RANGE end (src[0].src[1]) and branch back while smaller
        kk(self.asm_for_op[BinaryOps.ADD](r[src[0]], r[src[0]], "1", dtypes.int, self.types[dtypes.int]),
            self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[src[0]], r[src[0].src[1]], dtypes.int, self.types[dtypes.int]))
        kk(*self.render_bra(f"LOOP_{r[src[0]][1:]}", pred))
      elif uop is UOps.ENDIF:
        kk(f"IF_{r[src[0].src[0]][1:]}_{cast(List, uops._uops).index(src[0])}:")
      elif uop is UOps.STORE:
        assert src[0].dtype is not None and src[2].dtype is not None
        assert src[0].dtype == dtypes.int64, "store isn't int64"
        assert src[1].op is UOps.CONST, f"store isn't const {u}"
        # addresses derived from a DEFINE_LOCAL use the .shared state space, everything else .global
        mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global'
        if src[2].dtype.count > 1:
          kk((f"@{r[src[3]]} " if len(src)>3 else "") + \
              f"st{mem_type}.v{src[2].dtype.count}.{self.mem_types[src[2].dtype.scalar()]} [{r[src[0]]}+{src[1].arg}], {{{', '.join(r[src[2]])}}};")
        else:
          kk(*self.render_store(r[src[0]], r[src[2]], src[2].dtype, gate=r[src[3]] if len(src)>3 else None, ss=mem_type, offset=src[1].arg))
      else:
        assert dtype is not None, f"None dtype for uop {uop}"
        if uop is UOps.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[src[0]], "LOOP_"+loop[1:]))
        elif uop is UOps.ALU:
          assert src[0].dtype is not None
          if args is BinaryOps.CMPLT or args is BinaryOps.CMPNE:
            # pass in the other dtype here
            kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in src], src[0].dtype, self.types[src[0].dtype]))
          else:
            kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in src], dtype, self.types[dtype]))
        elif uop is UOps.DEFINE_ACC:
          if dtype.count > 1:
            r[u] = [ssa('acc', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
            for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(src[0].arg, dtype.scalar())};")
          else: kk(f"mov.b{self.types[dtype][1:]} {ssa('acc', u)}, {const(src[0].arg, dtype)};")
        elif uop is UOps.SPECIAL:
          assert args[1][0] != "i", "idx not supported"
          kk(f"mov.u32 %{args[1]}, {(self.gid if args[1][0] == 'g' else self.lid)[args[0]]};")
          r[u] = "%" + args[1]
          # declare the named register at the top of the kernel body
          kernel = [f".reg .u32 %{args[1]};"] + kernel
        elif uop is UOps.CONST:
          if dtype.count > 1: r[u] = [const(args, dtype.scalar(), mov=True) for _ in range(dtype.count)]
          else: r[u] = const(args, dtype, mov=True)
        elif uop is UOps.GEP: r[u] = r[src[0]][u.arg]
        elif uop is UOps.LOAD:
          assert src[0].dtype == dtypes.int64, "load isn't int64"
          assert src[1].op is UOps.CONST, f"load isn't const {u}"
          mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global'
          if dtype.count > 1:
            r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
            # NOTE(review): gated vector loads pre-fill with 0 rather than with the alt value src[3] — confirm intended
            if(len(src)>3):
              for v in r[u]: kk(f"mov.{self.mem_types[dtype.scalar()]} {v}, {render_val(0, dtype.scalar())};")
            kk((f"@{r[src[2]]}"if len(src) > 3 else "")
              + f" ld{mem_type}.v{dtype.count}.{self.mem_types[dtype.scalar()]} {{{', '.join(r[u])}}}, [{r[src[0]]}+{src[1].arg}];")
          else:
            kk(*self.render_load(r[src[0]], ssa('val', u), dtype, gate=r[src[2]] if len(src) > 3 else None,
                                alt=r[src[3]] if len(src) > 3 else None, ss=mem_type, offset=src[1].arg))
        elif uop is UOps.PHI:
          # write the new value back into the accumulator register(s) of src[0]
          if dtype.count > 1:
            for x0, x1 in zip(r[src[0]], r[src[1]]): kk(f"mov.b{self.types[dtype.scalar()][1:]} {x0}, {x1};")
          else:
            kk(f"mov.b{self.types[dtype][1:]} {r[src[0]]}, {r[src[1]]};")
          r[u] = r[src[0]]
        elif uop in {UOps.CAST, UOps.BITCAST}:
          assert src[0].dtype is not None
          # a vector cast just regroups the registers of its sources
          if dtype.count>1: r[u] = [r[x] for x in src] # type: ignore
          else: _cast(r[src[0]], dtype, src[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
        elif uop is UOps.DEFINE_LOCAL:
          # TODO: we should sum these, and fetch 0xC000 from somewhere
          assert args[1]*dtype.itemsize <= 0xC000, "too large local"
          kk(*self.render_local(ssa('local', u, self.types[dtypes.ulong]), args[0], args[1], dtype))
        elif uop is UOps.DEFINE_VAR:
          bufs.append((args.expr, dtype))
          r[u] = f"%{args.expr}"
          kk(*self.render_load(args.expr, ssa('dat', u, self.types[dtype]), dtype, ss=".param"))
        elif uop is UOps.DEFINE_GLOBAL:
          bufs.append((nm:=f"data{args[0]}", dtype))
          r[u] = f"%{nm}"
          # pointer params are loaded as u64 addresses
          dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
          kk(*self.render_load(nm, ssa('dat', u, self.types[dt]), dt, ss=".param"))
        elif uop is UOps.WMMA:
          # pack pairs of half registers from the A and B operands into b32 registers for mma.sync
          wmma = []
          for vv in src[:2]:
            for i in range(0, len(r[vv]), 2):
              wmma.append(ssa("wmma", dtype="b32"))
              kk(f'mov.b32 {wmma[-1]}, {{{", ".join(r[vv][i:i+2])}}};')
          r[u] = [ssa("wmma", dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
          kk(f'mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32\
            {{{", ".join(r[u])}}}, {{{", ".join(wmma[:4])}}}, {{{", ".join(wmma[4:])}}}, {{{", ".join(r[src[2]])}}};')
        else: raise NotImplementedError(f"no code for {uop}")

    return self.render_kernel(kernel, name, bufs, c.items())
|
226
|
-
|
227
|
-
# dtype / constant sets shared by the strength-reduction patterns below
_ptx_int_dtypes = set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)])
_ptx_uint_dtypes = set([dt for dt in _ptx_int_dtypes if dtypes.is_unsigned(dt)])
_ptx_pow2 = set([2**i for i in range(64)])

# PTX-specific UOp rewrites, applied by PTXRenderer.render through uops.linearize(ptx_matcher).
ptx_matcher = PatternMatcher([
  # x * 2**k -> x << k. The list src lets the constant sit on either side, which is fine because MUL commutes.
  (UPat(UOps.ALU, BinaryOps.MUL, name="root", dtype=_ptx_int_dtypes,
    src=[UPat(UOps.CONST, _ptx_pow2, name="const"), UPat(name="mul")]),
    lambda root, mul, const: UOp(UOps.ALU, root.dtype, (mul, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHL)),
  # x // 2**k -> x >> k.
  # The src is an ordered tuple so only a constant *divisor* matches; a list src also matched 2**k // x and
  # rewrote it to x >> k. Limited to unsigned dtypes: for negative signed x, div truncates toward zero while an
  # arithmetic right shift rounds toward -inf, so the rewrite would change results.
  (UPat(UOps.ALU, BinaryOps.IDIV, name="root", dtype=_ptx_uint_dtypes,
    src=(UPat(name="div"), UPat(UOps.CONST, _ptx_pow2, name="const"))),
    lambda root, div, const: UOp(UOps.ALU, root.dtype, (div, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHR)),
  # bool != y -> bool ^ y
  (UPat(UOps.ALU, BinaryOps.CMPNE, (UPat(dtype=dtypes.bool),UPat()), "root"), lambda root: UOp(root.op, root.dtype, root.src, BinaryOps.XOR)),
  # bool x < y -> (!x) * y, where MUL on bool renders as `and`
  (UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(name="x", dtype=dtypes.bool),UPat(name="y")), "root"),
    lambda root,x,y: UOp(root.op, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
  # a*b + c -> MULACC(a, b, c), rendered as fma.rn / mad.lo
  (UPat(UOps.ALU, BinaryOps.ADD,
    [UPat(name="non_muls"), UPat(UOps.ALU, BinaryOps.MUL, name="muls")], "root"),
    lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.src + (non_muls,), TernaryOps.MULACC)),
  # half ALU ops without native f16 rendering are computed in float32 and cast back to half
  *[(UPat(UOps.ALU, op, dtype=dtypes.half, name="x"),
    lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.op, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.src]), x.arg),)))
   for op in PTXRenderer.asm_for_op.keys() if op not in PTXRenderer.supports_half],
  # bool loads read a uint8 (mem_types[bool] is "u8") and cast the result to bool.
  # The gated form (x, y, gate z, alt k) now uses uint8 like the ungated one, and keeps root.arg on the LOAD itself.
  (UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(name="x"),UPat(name="y"),UPat(name="z"),UPat(name="k"))),
    lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.op, dtypes.uint8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,))), root.arg),))),
  (UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(),UPat())),
    lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.op, dtypes.uint8, root.src, root.arg),))),
  # bool stores write a uint8. The gated form must keep its gate (src[3]); dropping it made the store unconditional.
  (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool), UPat())),
    lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),) + root.src[3:], root.arg)),
  (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool))),
    lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
  # int store gates become predicates
  (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(),UPat(name="g", dtype=dtypes.int))),
    lambda root,g: UOp(root.op, root.dtype, root.src[:3] + (UOp(UOps.CAST, dtypes.bool, (g,)),), root.arg)),
  # ptr_ar (load/store): turn (buffer, index) into (int64 byte address, constant byte offset) as render expects
  (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
                              UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="alu"), UPat(UOps.CONST, name="const")]))),
    lambda root, alu, const: UOp(root.op, root.dtype,
      (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
       UOp.const(const.dtype, root.src[0].dtype.itemsize)*const)+root.src[2:])),
  (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
                              UPat(UOps.CONST, name="const"))),
    lambda root, const: UOp(root.op, root.dtype, (root.src[0].cast(dtypes.int64),
                                  UOp.const(dtypes.int64, const.arg * root.src[0].dtype.itemsize),
                                                  )+root.src[2:])),
  (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
                              UPat(name="alu"))),  # no const here
    lambda root, alu: UOp(root.op, root.dtype,
      (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
        UOp.const(dtypes.int64, 0))+root.src[2:])),
])
|
tinygrad/shape/symbolic.py
DELETED
@@ -1,327 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
import functools
|
3
|
-
from math import gcd
|
4
|
-
from tinygrad.helpers import partition
|
5
|
-
from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any, Set, Mapping
|
6
|
-
|
7
|
-
# NOTE: Python has different behavior for negative mod and floor div than c
|
8
|
-
# symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod
|
9
|
-
|
10
|
-
class Node:
  """Base class for nodes of the symbolic integer expression tree.

  Every node carries bounds `min` (int) and `max` (int or symbolic). Equality and hashing use `key`, the node's
  DEBUG rendering, so structurally identical expressions compare equal.
  """
  b: Union[Node, int]
  min: int
  max: sint
  def render(self, ops=None, ctx=None) -> Any:
    """Render via `ops`, a dict from node class to render callback (defaults to `render_python`)."""
    if ops is None: ops = render_python
    # compound nodes with min == max should already have been folded into a NumNode by create_node
    assert self.__class__ in (Variable, NumNode) or self.min != self.max
    return ops[type(self)](self, ops, ctx)
  def vars(self) -> Set[Variable]: return set()
  # substitute Variables with the values in var_vals
  def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: raise RuntimeError(self.__class__.__name__)
  def unbind(self) -> Tuple[Node, Optional[int]]:
    """Return (this expression with every bound Variable replaced by its unbound copy, None).

    `_val` is read directly because the `val` property asserts on unbound variables, which made the
    `is not None` filter raise instead of skipping them.
    """
    return self.substitute({v: v.unbind()[0] for v in self.vars() if v._val is not None}), None

  @functools.cached_property
  def key(self) -> str: return self.render(ctx="DEBUG")
  def __repr__(self): return self.render(ctx="REPR")
  def __str__(self): return "<"+self.key+">"
  def __hash__(self): return hash(self.key)
  # a node is falsy only when it is exactly the constant 0
  def __bool__(self): return not (self.max == self.min == 0)
  def __eq__(self, other:object) -> bool:
    if not isinstance(other, Node): return NotImplemented
    return self.key == other.key
  def __neg__(self): return self*-1
  def __add__(self, b:Union[Node,int]): return Node.sum([self, NumNode(b) if isinstance(b, int) else b])
  def __radd__(self, b:int): return self+b
  def __sub__(self, b:Union[Node,int]): return self+-b
  def __rsub__(self, b:int): return -self+b
  # every comparison reduces to `<` over integers: a <= b <=> a < b+1, a > b <=> -a < -b, a >= b <=> -a < -b+1
  def __le__(self, b:Union[Node,int]): return self < (b+1)
  def __gt__(self, b:Union[Node,int]): return (-self) < (-b)
  def __ge__(self, b:Union[Node,int]): return (-self) < (-b+1)
  def __lt__(self, b:Union[Node,int]): return create_node(LtNode(self, b))
  def __mul__(self, b:Union[Node, int]):
    if b == 0: return NumNode(0)
    if b == 1: return self
    return create_node(MulNode(self, b.b)) if isinstance(b, NumNode) else create_node(MulNode(self, b))
  def __rmul__(self, b:int): return self*b

  # *** complex ops ***

  def __rfloordiv__(self, b:int): return NumNode(b) // self
  def __floordiv__(self, b:Union[Node,int], factoring_allowed=True):
    """Floor division. Symbolic divisors are supported only in the trivial cases below (else RuntimeError)."""
    if isinstance(b, Node):
      if b.__class__ is NumNode: return self.__floordiv__(b.b, factoring_allowed)
      if self == b: return NumNode(1)
      if (b - self).min > 0 and self.min >= 0: return NumNode(0) # b - self simplifies the node
      raise RuntimeError(f"not supported: {self} // {b}")
    assert b != 0
    # floor(x / b) == floor(-x / -b)
    if b < 0: return (self*-1).__floordiv__(-b, factoring_allowed)
    if b == 1: return self

    # the numerator of div is not allowed to be negative
    if self.min < 0:
      offset = self.min//b
      # factor out an "offset" to make the numerator non-negative; don't allow factoring again
      return (self + -offset*b).__floordiv__(b, factoring_allowed=False) + offset
    return create_node(DivNode(self, b))

  def __rmod__(self, b:int): return NumNode(b) % self
  def __mod__(self, b:Union[Node,int]):
    """Python-style modulo. Symbolic divisors are supported only in the trivial cases below (else RuntimeError)."""
    if isinstance(b, Node):
      if b.__class__ is NumNode: return self % b.b
      if self == b: return NumNode(0)
      if (b - self).min > 0 and self.min >= 0: return self # b - self simplifies the node
      raise RuntimeError(f"not supported: {self} % {b}")
    assert b > 0
    if b == 1: return NumNode(0)
    if isinstance(self.max, int) and isinstance(self.min, int):
      # already inside [0, b): nothing to do
      if self.min >= 0 and self.max < b: return self
      # the whole range lies in one period: subtract that period's base
      if (self.min//b) == (self.max//b): return self - (b*(self.min//b))
      # shift a negative range up by a multiple of b (which leaves the residue unchanged)
      if self.min < 0: return (self - ((self.min//b)*b)) % b
    return create_node(ModNode(self, b))

  @staticmethod
  def sum(nodes:List[Node]) -> Node:
    """Sum `nodes`, folding constants and merging like terms (a*k1 + a*k2 -> a*(k1+k2))."""
    # drop exact zeros
    nodes = [x for x in nodes if x.max or x.min]
    if not nodes: return NumNode(0)
    if len(nodes) == 1: return nodes[0]

    mul_groups: Dict[Node, int] = {}
    num_node_sum = 0
    for node in SumNode(nodes).flat_components:
      if node.__class__ is NumNode: num_node_sum += node.b
      elif node.__class__ is MulNode: mul_groups[node.a] = mul_groups.get(node.a, 0) + node.b
      else: mul_groups[node] = mul_groups.get(node, 0) + 1
    new_nodes = [MulNode(a, b_sum) if b_sum != 1 else a for a, b_sum in mul_groups.items() if b_sum != 0]
    if num_node_sum: new_nodes.append(NumNode(num_node_sum))
    return create_node(SumNode(new_nodes)) if len(new_nodes) > 1 else new_nodes[0] if len(new_nodes) == 1 else NumNode(0)

  @staticmethod
  def ands(nodes:List[Node]) -> Node:
    """Logical AND of `nodes`: 1 for no nodes, 0 if any node is the constant 0."""
    if not nodes: return NumNode(1)
    if len(nodes) == 1: return nodes[0]
    if any(not x for x in nodes): return NumNode(0)

    # filter 1s
    nodes = [x for x in nodes if x.min != x.max]
    return create_node(AndNode(nodes)) if len(nodes) > 1 else (nodes[0] if len(nodes) == 1 else NumNode(1))
|
107
|
-
|
108
|
-
# 4 basic node types
|
109
|
-
|
110
|
-
class Variable(Node):
  """A named symbolic integer ranging over [min, max]; it can be bound to one concrete value via `bind`."""
  def __new__(cls, *args):
    expr, nmin, nmax = args
    assert 0 <= nmin <= nmax, f"invalid Variable {expr=} {nmin=} {nmax=}"
    # a single-valued range is just a constant, so hand back a NumNode instead of a Variable
    return NumNode(nmin) if nmin == nmax else super().__new__(cls)

  def __getnewargs__(self):
    # unpickling calls __new__ again with these arguments
    return self.expr, self.min, self.max

  def __init__(self, expr:str, nmin:int, nmax:sint):
    self.expr = expr
    self.min = nmin
    self.max = nmax
    self._val: Optional[int] = None

  @property
  def val(self):
    """The bound value; asserts when the variable is unbound."""
    bound = self._val
    assert bound is not None, f"Variable isn't bound, can't access val of {self}"
    return bound

  def bind(self, val):
    """Bind `val` (must lie in [min, max], and the variable must be unbound); returns self."""
    assert self._val is None and self.min<=val<=self.max, f"cannot bind {val} to {self}"
    self._val = val
    return self

  def unbind(self) -> Tuple[Variable, int]:
    """Return (a fresh unbound copy of this variable, the bound value)."""
    assert self.val is not None, f"cannot unbind {self}"
    fresh = Variable(self.expr, self.min, self.max)
    return fresh, self.val

  def vars(self):
    return {self}

  def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
    # unmapped variables stay as themselves
    return var_vals.get(self, self)
|
135
|
-
|
136
|
-
class NumNode(Node):
  """Integer constant leaf: `b` holds the value and min == max == b."""
  def __init__(self, num:int):
    assert isinstance(num, int), f"{num} is not an int"
    self.b:int = num
    self.min = self.max = num

  def bind(self, val):
    """Binding a constant only checks that `val` equals it; returns self."""
    assert self.b == val, f"cannot bind {val} to {self}"
    return self

  def __mul__(self, b:Union[Node,int]):
    if not isinstance(b, int): return b*self.b  # let the symbolic operand build the product
    return NumNode(self.b*b)

  # compares by value, so NumNode(3) == 3 holds
  def __eq__(self, other): return self.b == other
  def __hash__(self): return hash(self.b) # needed with __eq__ override

  def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
    return self  # constants contain no variables
|
148
|
-
|
149
|
-
def create_node(ret:Node):
  """Sanity-check a newly built node's bounds and fold it into a NumNode when min == max."""
  lo, hi = ret.min, ret.max
  assert lo <= hi, f"min greater than max! {ret.min} {ret.max} when creating {type(ret)} {ret}"
  return NumNode(lo) if lo == hi else ret
|
153
|
-
|
154
|
-
def create_lt_node(lhs:Node, b:Union[Node, int]):
  """Build the node for `lhs < b`, simplifying where possible.

  SumNode lhs with an int bound: constant terms are moved into the bound; then, if the MulNode terms with a positive
  coefficient and max >= b share a gcd g with b and the remaining terms lie in [0, g), both sides are divided by g.
  MulNode lhs `a*k` with int k != -1 and an int bound becomes `sign(k)*a < ceil(b/|k|)`.
  Everything else becomes a plain LtNode (folded by create_node when its bounds coincide).
  """
  if isinstance(lhs, SumNode):
    if isinstance(b, int):
      new_sum = []
      for x in lhs.nodes:
        # TODO: should we just force the last one to always be the number
        # move constant terms to the right-hand side: (s + c) < b  <=>  s < b - c
        if isinstance(x, NumNode): b -= x.b
        else: new_sum.append(x)
      lhs = Node.sum(new_sum)
      nodes = lhs.nodes if isinstance(lhs, SumNode) else [lhs]
      assert all(not isinstance(node, MulNode) or isinstance(node.b, int) for node in nodes), "not supported"
      muls, others = partition(nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
      if muls:
        # NOTE: gcd in python 3.8 takes exactly 2 args
        mul_gcd = b
        for x in muls: mul_gcd = gcd(mul_gcd, x.b) # type: ignore # mypy cannot tell that x.b is int here due to assert above
        all_others = Node.sum(others)
        # when the non-mul terms can never reach the next multiple of mul_gcd, they cannot change the comparison
        if all_others.min >= 0 and all_others.max < mul_gcd:
          lhs, b = Node.sum([mul//mul_gcd for mul in muls]), b//mul_gcd
    # the simplified lhs may no longer be a sum; if so, retry with the other rules
    return create_node(LtNode(lhs, b)) if isinstance(lhs, SumNode) else create_lt_node(lhs, b)
  if isinstance(lhs, MulNode):
    if isinstance(b, Node) or isinstance(lhs.b, Node) or lhs.b == -1: return create_node(LtNode(lhs, b))
    sgn = 1 if lhs.b > 0 else -1
    # (b + |k| - 1) // |k| is ceil(b / |k|)
    return create_node(LtNode(lhs.a*sgn, (b + abs(lhs.b) - 1)//abs(lhs.b)))
  return create_node(LtNode(lhs, b))
|
179
|
-
|
180
|
-
def create_ge_node(lhs:Node, b:Union[Node, int]):
  """Build `lhs >= b` via the integer identity lhs >= b <=> -lhs < -b + 1."""
  flipped = -lhs
  return create_lt_node(flipped, -b + 1)
|
181
|
-
|
182
|
-
class OpNode(Node):
  """Binary node over operands `a` (a Node) and `b` (a Node or int); subclasses supply get_bounds."""
  def __init__(self, a:Node, b:Union[Node, int]):
    self.a = a
    self.b = b
    # bounds are computed eagerly, once, from the operands
    lo, hi = self.get_bounds()
    self.min, self.max = lo, hi

  def vars(self):
    own = self.a.vars()
    extra = self.b.vars() if isinstance(self.b, Node) else set()
    return own | extra

  def get_bounds(self) -> Tuple[int, sint]:
    raise NotImplementedError("must be implemented")
|
188
|
-
|
189
|
-
class LtNode(OpNode):
  """Node for the comparison `a < b`.

  Bounds are (1, 1) when the comparison is provably true, (0, 0) when provably false, otherwise (0, 1)."""
  def get_bounds(self) -> Tuple[int, int]:
    if self.a == self.b: return (0, 0)  # x < x never holds
    # an int operand is its own lower and upper bound
    if isinstance(self.b, int): b_low, b_high = self.b, self.b
    else: b_low, b_high = self.b.min, self.b.max
    if self.a.max < b_low: return (1, 1)
    if self.a.min >= b_high: return (0, 0)
    return (0, 1)

  def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
    new_a = self.a.substitute(var_vals)
    new_b = self.b if isinstance(self.b, int) else self.b.substitute(var_vals)
    # rebuild through create_lt_node so the substituted comparison is simplified again
    return create_lt_node(new_a, new_b)
|
196
|
-
|
197
|
-
class MulNode(OpNode):
  """Node for the product `a * b` of a Node `a` and an int or Node coefficient `b`.

  get_bounds asserts that `a` is non-negative (a.min >= 0)."""
  def __mul__(self, b: Union[Node, int]): return self.a*(self.b*b) # two muls in one mul
  def __floordiv__(self, b: Union[Node, int], factoring_allowed=False): # NOTE: mod negative isn't handled right
    # (a*k)//b == a*(k//b) when b divides k
    if self.b % b == 0: return self.a*(self.b//b)
    # (a*k)//b == a//(b//k) when k > 0 divides b
    if b % self.b == 0 and self.b > 0: return self.a//(b//self.b)
    return Node.__floordiv__(self, b, factoring_allowed)
  # (a*k) % b == (a*(k % b)) % b, so reduce the coefficient first
  def __mod__(self, b: Union[Node, int]): return Node.__mod__(self.a * (self.b%b), b)
  def get_bounds(self) -> Tuple[int, sint]:
    assert self.a.min >= 0
    if isinstance(self.b, int): return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
    if self.b.min >= 0: return (self.a.min*self.b.min, self.a.max*self.b.max)
    # b can be negative; since a >= 0 the smallest product is always a.max*b.min
    lower = self.a.max*self.b.min
    # if b can also be positive, the largest product is a.max*b.max (a.min*b.max would underestimate it whenever
    # a.max > a.min); if b is never positive, the largest product is a.min*b.max.
    # NOTE(review): a symbolic b.max cannot be compared here, so it keeps the a.min*b.max bound
    b_can_be_positive = isinstance(self.b.max, int) and self.b.max > 0
    upper = self.a.max*self.b.max if b_can_be_positive else self.a.min*self.b.max
    return (lower, upper)
  def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
    return self.a.substitute(var_vals) * (self.b if isinstance(self.b, int) else self.b.substitute(var_vals))
|
210
|
-
|
211
|
-
class DivNode(OpNode):
  """Node for the floor division `a // b`.

  get_bounds asserts that `a` is non-negative and that `b` is an int.
  NOTE(review): the bounds (a.min//b, a.max//b) assume b > 0; a negative b would make them come out swapped."""
  def __floordiv__(self, b: Union[Node, int], _=False):
    # (a // c) // b is folded into a single a // (c*b)
    merged_divisor = self.b*b
    return self.a//merged_divisor

  def get_bounds(self) -> Tuple[int, sint]:
    assert self.a.min >= 0 and isinstance(self.b, int)
    low = self.a.min//self.b
    high = self.a.max//self.b
    return low, high

  def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
    substituted = self.a.substitute(var_vals)
    return substituted // self.b
|
217
|
-
|
218
|
-
class ModNode(OpNode):
  """Node for `a % b`; get_bounds asserts that `a` is non-negative and that the modulus `b` is an int."""
  def __mod__(self, b: Union[Node, int]):
    both_ints = isinstance(b, int) and isinstance(self.b, int)
    # (a % m) % b == a % b when b divides m
    if both_ints and self.b % b == 0: return self.a % b
    return Node.__mod__(self, b)

  def __floordiv__(self, b: Union[Node, int], factoring_allowed=True):
    # (a % m) // b == (a // b) % (m // b) when b divides m
    if self.b % b == 0: return (self.a//b) % (self.b//b)
    return Node.__floordiv__(self, b, factoring_allowed)

  def get_bounds(self) -> Tuple[int, sint]:
    assert self.a.min >= 0 and isinstance(self.b, int)
    a_low, a_high, modulus = self.a.min, self.a.max, self.b
    # a spans at least one full period: every residue is reachable
    if a_high - a_low >= modulus: return (0, modulus-1)
    # a's range wraps past a multiple of the modulus: residues restart from 0
    if a_low != a_high and a_low%modulus >= a_high%modulus: return (0, modulus-1)
    return (a_low%modulus, a_high%modulus)

  def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
    substituted = self.a.substitute(var_vals)
    return substituted % self.b
|
229
|
-
|
230
|
-
class RedNode(Node):
  """Base class for nodes that combine a list of child nodes; subclasses provide get_bounds()."""
  def __init__(self, nodes:List[Node]):
    self.nodes = nodes
    # bounds are computed once, after the children are stored
    self.min, self.max = self.get_bounds()

  def vars(self) -> Set[Variable]:
    # union of every child's variables; an empty child list yields an empty set
    return set().union(*(child.vars() for child in self.nodes))

  def get_bounds(self) -> Tuple[int, sint]: raise NotImplementedError("must be implemented")
|
236
|
-
|
237
|
-
class SumNode(RedNode):
  """Node for the sum of its child nodes."""
  # bounds of a sum are the sums of the children's bounds
  def get_bounds(self) -> Tuple[int, sint]: return sum([x.min for x in self.nodes]), sum([x.max for x in self.nodes])
  # NOTE(review): lru_cache on a method keys on `self`, so every SumNode passed through it stays alive for the process lifetime
  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
  def __mul__(self, b: Union[Node, int]): return Node.sum([x*b for x in self.nodes]) # distribute mul into sum
  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
  def __floordiv__(self, b: Union[Node, sint], factoring_allowed=True):
    """Floor-divide the sum by `b`, pulling terms that divide cleanly out of the division where possible."""
    if self == b: return NumNode(1)
    fully_divided: List[Node] = []  # terms (or parts of terms) already divided by b
    rest: List[Node] = []  # terms that still have to go through the division
    if isinstance(b, Node):
      # symbolic divisor: terms whose `x % b` compares equal to 0 are divided individually
      for x in self.flat_components:
        if x % b == 0: fully_divided.append(x // b)
        else: rest.append(x)
      if (sum_fully_divided:=create_node(SumNode(fully_divided))) != 0: return sum_fully_divided + create_node(SumNode(rest)) // b
      return Node.__floordiv__(self, b, False)
    if b == 1: return self
    if not factoring_allowed: return Node.__floordiv__(self, b, factoring_allowed)
    _gcd = b  # gcd of b and every int coefficient left in `rest`; forced to 1 by any term without an int coefficient
    divisor = 1  # coefficient of a MulNode in `rest` that divides b (the first such one that is not 1)
    for x in self.flat_components:
      if x.__class__ in (NumNode, MulNode):
        if x.b % b == 0: fully_divided.append(x // b)
        else:
          # split a constant into its multiple of b (divided out now) and its remainder
          if x.__class__ is NumNode and (div := x.b // b):
            fully_divided.append(NumNode(div))
            x = NumNode(x.b - b * div)
          rest.append(x)
          if isinstance(x.b, int):
            _gcd = gcd(_gcd, x.b)
            if x.__class__ == MulNode and divisor == 1 and b % x.b == 0: divisor = x.b
          else:
            _gcd = 1
      else:
        rest.append(x)
        _gcd = 1
    # for positive b and a factor g of b: x // b == (x // g) // (b // g), so divide the rest in two steps
    if _gcd > 1: return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(_gcd) // (b//_gcd)
    if divisor > 1: return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(divisor) // (b//divisor)
    return Node.sum(fully_divided) + Node.__floordiv__(Node.sum(rest), b)

  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
  def __mod__(self, b: Union[Node, int]):
    """Reduce the sum modulo `b`, first reducing each NumNode/MulNode term modulo b."""
    if self == b: return NumNode(0)
    # NOTE(review): returning self when b - self > 0 assumes self is non-negative — confirm
    if isinstance(b, Node) and (b - self).min > 0: return self # b - self simplifies the node
    new_sum = Node.sum([node%b if node.__class__ in (NumNode, MulNode) else node for node in self.nodes])
    return Node.__mod__(new_sum, b)

  def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
    # re-summing lets Node.sum simplify the substituted terms
    return Node.sum([node.substitute(var_vals) for node in self.nodes])

  # recursively expand sumnode components
  # TODO: can remove this if there's no SumNode inside SumNode
  @property
  def flat_components(self): return [y for x in self.nodes for y in (x.flat_components if isinstance(x, SumNode) else [x])]
|
290
|
-
|
291
|
-
class AndNode(RedNode):
  """Conjunction of child nodes.

  Bounds are the smallest child minimum and the largest child maximum."""
  def get_bounds(self) -> Tuple[int, sint]:
    child_mins = [x.min for x in self.nodes]
    child_maxes = [x.max for x in self.nodes]
    return min(child_mins), max(child_maxes)

  def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
    substituted: List[Node] = []
    for child in self.nodes:
      new_child = child.substitute(var_vals)
      # short-circuit: a falsy substituted term makes the whole conjunction 0
      if not new_child: return NumNode(0)
      substituted.append(new_child)
    return Node.ands(substituted)
|
299
|
-
|
300
|
-
def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str:
  """Render `a` as text: ints via str(), anything else via its own render(ops, ctx)."""
  if isinstance(a, int): return str(a)
  return a.render(ops, ctx)
|
301
|
-
def sym_infer(a: Union[Node, int], var_vals: Optional[Dict[Variable, int]]) -> int:
  """Evaluate `a` to a concrete number.

  Plain ints and floats are returned unchanged. For a Node, each variable in `var_vals` is replaced by a NumNode
  of its value (no substitution when var_vals is None), and the result must reduce to a NumNode.
  Raises AssertionError if it does not."""
  if isinstance(a, (int, float)): return a
  if var_vals is None: ret = a
  else: ret = a.substitute({var: NumNode(val) for var, val in var_vals.items()})
  assert isinstance(ret, NumNode), f"sym_infer didn't produce NumNode from {a} with {var_vals}"
  return ret.b
|
306
|
-
|
307
|
-
# symbolic int, these are allowed in a Tensor shape
# either a plain int or a symbolic expression built from a Variable, a MulNode, or a SumNode
sint = Union[int, Variable, MulNode, SumNode]
|
309
|
-
|
310
|
-
def render_mulnode(node:MulNode, ops, ctx):
  """Render a MulNode as "(lhs*rhs)".

  When both factors are Variables with non-empty names, the one whose name sorts first is written on the left,
  so Variable*Variable products render the same regardless of operand order."""
  # TODO: add ProdNode and remove this case
  lhs, rhs = node.a, node.b
  swap = isinstance(lhs,Variable) and isinstance(rhs,Variable) and lhs.expr and rhs.expr and rhs.expr < lhs.expr
  if swap: return f"({sym_render(rhs,ops,ctx)}*{lhs.render(ops,ctx)})"
  return f"({lhs.render(ops,ctx)}*{sym_render(rhs,ops,ctx)})"
|
315
|
-
|
316
|
-
# Default renderers: maps each node class to a function (node, ops, ctx) -> str.
# `ctx` selects the flavor: "DEBUG" and "REPR" change the Variable/NumNode output, anything else gives plain expressions.
render_python: Dict[Type, Callable[..., str]] = {
  # Variable: DEBUG -> "expr[min-max=val]" ("=val" only when bound), REPR -> "Variable('expr', min, max)" plus
  # ".bind(val)" when bound, otherwise just the name
  Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self._val is not None else ''}]" if ctx == "DEBUG" \
    else (f"Variable('{self.expr}', {self.min}, {self.max})"+(f".bind({self.val})" if self._val is not None else '') if ctx == "REPR" \
    else f"{self.expr}"),
  NumNode: lambda self,ops,ctx: f"NumNode({self.b})" if ctx == "REPR" else f"{self.b}",
  MulNode: render_mulnode,
  DivNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}//{self.b})",
  ModNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}%{self.b})",
  LtNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}<{sym_render(self.b,ops,ctx)})",
  # sum and and-terms are sorted by their rendered text, so the output does not depend on term order
  SumNode: lambda self,ops,ctx: f"({'+'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
  AndNode: lambda self,ops,ctx: f"({' and '.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
}
|