tinygrad 0.10.0__py3-none-any.whl → 0.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/kernel.py +114 -172
- tinygrad/codegen/linearize.py +211 -81
- tinygrad/codegen/lowerer.py +30 -35
- tinygrad/codegen/{uopgraph.py → rewriter.py} +69 -59
- tinygrad/codegen/transcendental.py +12 -13
- tinygrad/device.py +170 -47
- tinygrad/dtype.py +28 -26
- tinygrad/engine/jit.py +80 -63
- tinygrad/engine/memory.py +4 -5
- tinygrad/engine/multi.py +162 -0
- tinygrad/engine/realize.py +58 -107
- tinygrad/engine/schedule.py +381 -314
- tinygrad/engine/search.py +40 -44
- tinygrad/gradient.py +70 -0
- tinygrad/helpers.py +77 -58
- tinygrad/nn/__init__.py +30 -32
- tinygrad/nn/datasets.py +1 -2
- tinygrad/nn/optim.py +22 -26
- tinygrad/nn/state.py +89 -64
- tinygrad/ops.py +562 -446
- tinygrad/renderer/__init__.py +79 -36
- tinygrad/renderer/cstyle.py +70 -84
- tinygrad/renderer/llvmir.py +32 -20
- tinygrad/renderer/ptx.py +79 -99
- tinygrad/renderer/wgsl.py +87 -0
- tinygrad/runtime/autogen/amd_gpu.py +39507 -12
- tinygrad/runtime/autogen/comgr.py +2 -0
- tinygrad/runtime/autogen/kfd.py +4 -3
- tinygrad/runtime/autogen/kgsl.py +1 -1
- tinygrad/runtime/autogen/libpciaccess.py +2023 -0
- tinygrad/runtime/autogen/llvm.py +11379 -0
- tinygrad/runtime/autogen/vfio.py +891 -0
- tinygrad/runtime/graph/cuda.py +8 -9
- tinygrad/runtime/graph/hcq.py +84 -79
- tinygrad/runtime/graph/metal.py +19 -21
- tinygrad/runtime/ops_amd.py +488 -327
- tinygrad/runtime/ops_clang.py +15 -28
- tinygrad/runtime/ops_cloud.py +34 -34
- tinygrad/runtime/ops_cuda.py +30 -27
- tinygrad/runtime/ops_disk.py +62 -63
- tinygrad/runtime/ops_dsp.py +129 -38
- tinygrad/runtime/ops_gpu.py +30 -30
- tinygrad/runtime/ops_hip.py +29 -31
- tinygrad/runtime/ops_llvm.py +45 -40
- tinygrad/runtime/ops_metal.py +93 -73
- tinygrad/runtime/ops_npy.py +2 -2
- tinygrad/runtime/ops_nv.py +232 -270
- tinygrad/runtime/ops_python.py +51 -46
- tinygrad/runtime/ops_qcom.py +129 -157
- tinygrad/runtime/ops_webgpu.py +63 -0
- tinygrad/runtime/support/allocator.py +94 -0
- tinygrad/runtime/support/am/__init__.py +0 -0
- tinygrad/runtime/support/am/amdev.py +384 -0
- tinygrad/runtime/support/am/ip.py +463 -0
- tinygrad/runtime/support/compiler_cuda.py +4 -2
- tinygrad/runtime/support/elf.py +26 -4
- tinygrad/runtime/support/hcq.py +254 -324
- tinygrad/runtime/support/llvm.py +32 -0
- tinygrad/shape/shapetracker.py +84 -53
- tinygrad/shape/view.py +103 -138
- tinygrad/spec.py +154 -0
- tinygrad/tensor.py +744 -496
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/METADATA +32 -21
- tinygrad-0.10.1.dist-info/RECORD +86 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/WHEEL +1 -1
- tinygrad/engine/lazy.py +0 -228
- tinygrad/function.py +0 -212
- tinygrad/multi.py +0 -177
- tinygrad/runtime/graph/clang.py +0 -39
- tinygrad-0.10.0.dist-info/RECORD +0 -77
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/top_level.txt +0 -0
tinygrad/codegen/{uopgraph.py → rewriter.py}

@@ -1,14 +1,13 @@
 from __future__ import annotations
-from typing import Optional,
+from typing import Optional, Any, Callable
 import functools, itertools, operator
 from collections import defaultdict
 from tinygrad.dtype import dtypes, ImageDType, PtrDType
-from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, symbolic_flat, symbolic_simple
+from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, symbolic_flat, symbolic_simple, resolve
 from tinygrad.ops import graph_rewrite, split_uop, uop_given_valid, parse_valid, is_increasing, simplify_valid, GroupOp
 from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same
 from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
-
-if TYPE_CHECKING: from tinygrad.renderer import Renderer
+from tinygrad.renderer import Renderer
 
 # ***** float4/image store handling *****
 
@@ -19,7 +18,7 @@ def fold_expanded(ex, buf):
   is_load, is_image = new_srcs[0].op is Ops.LOAD, isinstance(buf.dtype, ImageDType)
 
   # first, extract all the relevant offsets
-  offsets_rootsrc:
+  offsets_rootsrc: defaultdict[Any, dict] = defaultdict(dict)
   for i,s in enumerate(new_srcs):
     idx = s.src[0].src[1]
     if s.dtype.count != 1 or (is_image and idx.dtype.count == 2): continue
@@ -33,7 +32,7 @@ def fold_expanded(ex, buf):
 
   # then rewrite everything we can
   lengths = [4] if is_image else ([8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2]))
-  used:
+  used: set[tuple[UOp, UOp]] = set()
   for rootsrc, offsets in offsets_rootsrc.items():
     for o in offsets:
       for fold_length in lengths:
@@ -49,7 +48,8 @@ def fold_expanded(ex, buf):
               rootsrc[0] if isinstance(rootsrc, tuple) else None)
           else:
             # for non image, we upcast the index pointer
-            new_src[0] = new_src[0].cast(new_src[0].dtype.base.vec(fold_length).ptr(new_src[0].dtype.
+            new_src[0] = new_src[0].cast(new_src[0].dtype.base.vec(fold_length).ptr(size=new_src[0].dtype.size//fold_length,
+                                                                                    local=new_src[0].dtype.local))
           # generate the folded new_srcs
           if is_load:
             new_load = UOp(Ops.LOAD, load_1.dtype.vec(fold_length), tuple(new_src))
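In the non-image branch, the pointer is re-cast so that fold_length adjacent scalar accesses become one vector access. A pure-Python sketch of what folding four adjacent float loads into one float4 load means (illustrative only, not the tinygrad API):

import struct

def load_scalar(buf: bytes, i: int) -> float:
    # one float32 load at element index i
    return struct.unpack_from("<f", buf, i * 4)[0]

def load_vec4(buf: bytes, i: int) -> tuple:
    # one float4 load at element index i: 4 contiguous floats, one unpack
    return struct.unpack_from("<4f", buf, i * 4)

buf = struct.pack("<8f", *range(8))
assert tuple(load_scalar(buf, i) for i in range(4)) == load_vec4(buf, 0)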
@@ -83,7 +83,7 @@ float4_folding = PatternMatcher([
 
 # ***** image load valid simplification *****
 
-def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) ->
+def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
   if (idx:=uop_given_valid(valid, start_idx)) is None: return buf.const_like(0)
   if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx, valid)
 
@@ -122,19 +122,17 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> Optional[UOp]:
 powers_of_two = {2**i:i for i in range(64)}
 @functools.lru_cache(None)
 def get_late_rewrite_patterns(ops, force_transcendental=False):
-  pat:
+  pat: list[tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \
     ((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)) if op not in ops or force_transcendental]
-  # rewrite MOD to AND (which should always be supported, but not for generic in tests)
+  # rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1)
   if Ops.AND in ops:
-    pat += [(UPat(
-
-  # rewrite MUL/IDIV to SHL+SHR
+    pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)]
+  # rewrite MUL/IDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y)
   if Ops.SHL in ops and Ops.SHR in ops:
     pat += [
-
-
-
-      div >> powers_of_two[const.arg] if const.arg in powers_of_two else None)] # (x // (2**y)) -> shr(x,y)
+      (UPat.var("x", dtypes.ints)*UPat.cvar("c"), lambda c,x: x << powers_of_two[c.arg] if c.arg in powers_of_two else None),
+      (UPat.var("x", dtypes.ints)//UPat.cvar("c"), lambda x,c: x >> powers_of_two[c.arg] if c.arg in powers_of_two and resolve(x>=0,False) else None)
+    ]
   if Ops.NEG in ops:
     pat += [(UPat.var('x')*-1, lambda x: x.alu(Ops.NEG))]
   if Ops.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(Ops.NEG), lambda x,y: x.alu(Ops.SUB, y))]
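These rewrites are the classic power-of-two strength reductions. Checking the identities in plain Python (note the resolve(x>=0, False) guard on the shift-right rule: in the rendered languages integer division truncates toward zero, so x//c == x>>y is only safe for non-negative x):

# for non-negative x and a power-of-two constant c == 2**y:
#   x % c  == x & (c-1)   (MOD -> AND)
#   x * c  == x << y      (MUL -> SHL)
#   x // c == x >> y      (IDIV -> SHR, hence the x >= 0 guard)
powers_of_two = {2**i: i for i in range(64)}
for x in range(0, 1000, 7):
    for c in (2, 4, 16, 64):
        y = powers_of_two[c]
        assert x % c == x & (c - 1)
        assert x * c == x << y
        assert x // c == x >> y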
@@ -191,7 +189,7 @@ def loop_collapse(compval, multconst, rng:UOp, acc:UOp, idx2=None,idx3=None,extr
 
 def index_collapse(idx:UOp,rng:UOp,buf:UOp,ld:UOp,acc:UOp,add=UOp.const(dtypes.int, 0),mul=UOp.const(dtypes.int, 1)):
   if rng not in acc.src: return None
-  new_load = UOp.load(buf.index(add+mul*idx, idx
+  new_load = UOp.load(buf.index(add+mul*idx, (idx >= rng.src[0]) & (idx < rng.src[1])), dtype=ld.dtype)
   new_acc = acc.replace(src=acc.src[0:1]+tuple(x for x in acc.src[1:] if x is not rng))
   return new_acc.assign(new_acc+new_load)
 
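index_collapse's new valid condition bounds-checks the selected index. The underlying observation, in plain Python: summing a loaded value gated by i == idx over the whole range selects exactly one element, so the loop collapses to a single guarded load:

buf = [10, 20, 30, 40]
idx = 2
# the unrolled reduction: every iteration contributes 0 except i == idx
looped = sum(buf[i] if i == idx else 0 for i in range(len(buf)))
# the collapsed form: one load, guarded by the range bounds
collapsed = buf[idx] if 0 <= idx < len(buf) else 0
assert looped == collapsed == 30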
@@ -221,7 +219,7 @@ def no_vectorized_wmma(wmma:UOp):
   return UOp(Ops.VECTORIZE, wmma.dtype, tuple(wmma_ex))
 
 def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
-  reduce_parented, reduce_unparented = partition(acc.src[1:], lambda x: x in ret.
+  reduce_parented, reduce_unparented = partition(acc.src[1:], lambda x: x in ret.toposort)
   if len(reduce_unparented) == 0: return None
   new_acc = acc.replace(src=acc.src[0:1]+tuple(reduce_parented))
   ret = new_acc.assign(new_acc.alu(alu.op, ret))
@@ -235,17 +233,18 @@ rng_aug = UPat.any(rng_pat, UPat.var("add")+rng_pat, UPat.var("mul")*rng_pat, UP
 index_load = UPat.var("buf").index(rng_aug).load(name="ld")
 
 arange_augrng = UPat.any(rng_aug, rng_aug+UPat.var("idx2"), rng_aug+UPat.var("idx2")+UPat.var("idx3"), UPat(Ops.VECTORIZE, name="vec", src=rng_aug))
-arange_m = arange_augrng
+arange_m = ((arange_augrng<UPat.cvar("compval"))!=UPat(Ops.CONST, name="ne", arg=True)).where(UPat.cvar("multconst"), UPat.const(None, 0))
+
+# this moves the accumulation variable down an unrolled add chain which allows for more efficient accumulation using mulacc
+mulacc_unrolled = PatternMatcher([(UPat.var("x")+UPat.var("y")+acc_pat, lambda x,y,acc: (acc+x)+y if y.op is not Ops.DEFINE_ACC else None)])
 
 # this is symbolic 2.0
 sym = symbolic_flat+PatternMatcher([
   # self ASSIGN is just self
   (UPat(Ops.ASSIGN, src=(UPat.var('x'), UPat.var('x'))), lambda x: x),
-  # ASSIGN to global is just self
-  (UPat(Ops.ASSIGN, src=(UPat(Ops.DEFINE_GLOBAL), UPat.var("x"))), lambda x: x),
   # VECTORIZE/CONST, VECTORIZE/GEP
   (UPat(Ops.VECTORIZE, src=UPat(Ops.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))),
-  (UPat(Ops.VECTORIZE, src=UPat(Ops.GEP, src=(UPat(
+  (UPat(Ops.VECTORIZE, src=UPat(Ops.GEP, src=(UPat.var("x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))),
   # reorder ALU/VECTORIZE
   (UPat(GroupOp.ALU, src=(UPat(Ops.VECTORIZE, src=UPat(name='x')), UPat(Ops.VECTORIZE, src=UPat(name='y'))), name='alu'),
    lambda x,y,alu: UOp(Ops.VECTORIZE, alu.dtype, (UOp(alu.op, alu.dtype.scalar(), (x,y)),)*alu.dtype.count)),
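mulacc_unrolled (added above) moves the DEFINE_ACC to the front of an unrolled add chain. A plain-Python sketch of the reassociation; the win, per the comment, is that each step then has the form accumulator + product, i.e. a multiply-accumulate:

xs, ws, acc = [1.0, 2.0, 3.0], [4.0, 5.0, 6.0], 0.5
# unrolled add chain with the accumulator at the end
unrolled = xs[0]*ws[0] + xs[1]*ws[1] + xs[2]*ws[2] + acc
# after the rewrite: accumulator threaded through each add
chained = ((acc + xs[0]*ws[0]) + xs[1]*ws[1]) + xs[2]*ws[2]
assert abs(unrolled - chained) < 1e-9  # same sum, different association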
@@ -288,14 +287,16 @@ sym = symbolic_flat+PatternMatcher([
   # indexing, with cast or where
   (acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).cast()*index_load+acc_pat), index_collapse),
   (acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).where(index_load, UPat.const(None, 0.0))+acc_pat), index_collapse),
-  # parentless reduce
-  (acc_pat.assign(UPat(Ops.ADD, src=[acc_pat, UPat.var("ret")], name="alu")), reduce_collapse),
-  (acc_pat.assign(UPat(Ops.MAX, src=[acc_pat, UPat.var("ret")], name="alu")), reduce_collapse),
+  # parentless reduce # TODO: add MUL
+  (acc_pat.assign(UPat((Ops.ADD, Ops.MAX), src=[acc_pat, UPat.var("ret")], name="alu")), reduce_collapse),
   # ** self folding **
   (UPat(Ops.DEFINE_ACC, src=(UPat.var("x"),)), lambda x: x), # a DEFINE_ACC without ranges is a CONST
   (UPat(Ops.ASSIGN, src=(UPat.cvar(),UPat.var("x"))), lambda x: x), # an ASSIGN to a const is a NOOP
   # x!=0 -> (bool)x
-  (UPat.var("x")
+  (UPat.var("x")!=0, lambda x: x.cast(dtypes.bool.vec(x.dtype.count))),
+  # ** where **
+  # push cast to branches
+  (UPat.var("s").where(UPat.var("a"), UPat.var("b")).cast().named("cast"), lambda s,a,b,cast: s.where(a.cast(cast.dtype), b.cast(cast.dtype))),
   # ** load/store folding **
   (UPat.store(UPat(Ops.INDEX, name="index"), UPat.load(UPat(Ops.INDEX, name="index"))), lambda index: UOp(Ops.NOOP)),
   (UPat.store(UPat(Ops.INDEX, name="index"), UPat.var("gate").where(UPat.var("alt"), UPat.load(UPat(Ops.INDEX, name="index")))),
@@ -308,31 +309,36 @@ sym = symbolic_flat+PatternMatcher([
   # remove NOOPs from SINK
   (UPat(Ops.SINK, name="root"),
    lambda root: UOp(Ops.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not Ops.NOOP)) != len(root.src) else None),
-  # remove
+  # remove VECTORIZE from SINK/BARRIER
   (UPat(Ops.BARRIER, src=(UPat((Ops.VECTORIZE, Ops.SINK), name='sink'),)), lambda sink: UOp(Ops.BARRIER, dtypes.void, sink.src)),
   (UPat(Ops.SINK, name="root"),
-   lambda root: UOp(Ops.SINK, root.dtype, tuple(flatten(x.src if x.op in {Ops.SINK, Ops.
-   if any(x.op in {Ops.SINK, Ops.
+   lambda root: UOp(Ops.SINK, root.dtype, tuple(flatten(x.src if x.op in {Ops.SINK, Ops.UNROLL} else (x,) for x in root.src)), root.arg)
+   if any(x.op in {Ops.SINK, Ops.UNROLL} for x in root.src) else None),
+  ((UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()), # 1/(x^c) -> (1/x)^c
+  ((UPat.var("x") * UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()*x.reciprocal()),
+  (UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")), lambda x,d: 1-d), # x*/(1+x) -> 1-1/(1+x)
+  (UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")*UPat.var("y")), lambda x,y,d: y*(1-d)),
+  (UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")+UPat.var("y")), lambda x,y,d: (1-d)+x*y),
 ])
 
 # *** uop expander ***
 
-def _expand_arg_to_idx(args:
+def _expand_arg_to_idx(args:tuple[tuple[int, int], ...], rpk:dict[int, int]) -> int:
   idx, mul = 0, 1
   for axis,m in args[::-1]:
     idx += rpk[axis] * mul
     mul *= m
   return idx
 
-def _choices_from_args(args:
+def _choices_from_args(args:tuple[tuple[int, int], ...]) -> list[dict[int, int]]:
   return [dict(x) for x in itertools.product(*[zip(itertools.repeat(axis), range(m)) for axis,m in args])]
 
 @functools.lru_cache(None)
-def _swizzle_args(cargs:
+def _swizzle_args(cargs:tuple[tuple[int, int], ...], eargs:tuple[tuple[int, int], ...], exclude_args:tuple[int, ...]) -> list[int]:
   return [_expand_arg_to_idx(eargs, {**rpk, **{x:0 for x in exclude_args}} if exclude_args else rpk) for rpk in _choices_from_args(cargs)]
 
 def do_expand(root:UOp):
-  expands = [x for x in root.src if x.op is Ops.
+  expands = [x for x in root.src if x.op is Ops.UNROLL]
   if len(expands) == 0: return None
   # NOTE: we 0 out the reduce axis for WMMA. in theory they should all be the same, but is this always correct?
   exclude_args = tuple(dedup(root.arg[-1] + tuple(y[0] for y in flatten(root.arg[-2])))) if root.op is Ops.WMMA else ()
@@ -345,7 +351,7 @@ def do_expand(root:UOp):
   expand_sz = prod([x[1] for x in expand_args])
   new_srcs = []
   for i,src in enumerate(root.src):
-    if src.op is Ops.
+    if src.op is Ops.UNROLL:
       if root.op is Ops.IF and i == 0:
         # IF means OR on first arg to IF
         new_srcs.append(functools.reduce(operator.__or__, [src.src[0].gep(i) for i in range(expand_sz)]))
@@ -358,9 +364,9 @@ def do_expand(root:UOp):
         if src.dtype.count > 1: lst = flatten([[i*src.dtype.count+j for j in range(src.dtype.count)] for i in lst])
         new_srcs.append(src.src[0].gep(tuple(lst)))
     else:
-      # non-
+      # non-UNROLL input
       if root.op is Ops.IF:
-        # for the first arg of IF, just pass them through ignoring
+        # for the first arg of IF, just pass them through ignoring UNROLLS
         new_srcs.append(src)
       elif src.dtype.count > 1:
         # put any input dtype > 1 grouped together
@@ -376,25 +382,25 @@ def do_expand(root:UOp):
     # is this right?
     new_arg = tuple(range(root.arg[0], new_srcs[0].dtype.count, new_srcs[0].dtype.count // expand_sz))
   nsrc = UOp(root.op, root.dtype.scalar().vec(root.dtype.count*expand_sz), tuple(new_srcs), new_arg)
-  return UOp(Ops.
+  return UOp(Ops.UNROLL, root.dtype, (nsrc,), expand_args)
 
 def do_contract(con:UOp):
   ex = con.src[0]
-  # CONTRACT without
-  if ex.op is not Ops.
-  # CONTRACT may remove several axes from
+  # CONTRACT without UNROLL repeats the element VECTORIZED
+  if ex.op is not Ops.UNROLL: return UOp(Ops.VECTORIZE, con.dtype, con.src*con.dtype.count)
+  # CONTRACT may remove several axes from UNROLL
   assert con.dtype.count == prod([x[1] for x in con.arg]), "dtype is wrong"
   idxs = []
   for rpk in _choices_from_args(new_ex_args:=tuple(x for x in ex.arg if x not in con.arg)):
     idxs += [_expand_arg_to_idx(ex.arg, {**rpk, **lrpk}) for lrpk in _choices_from_args(con.arg)]
-  return UOp(Ops.
+  return UOp(Ops.UNROLL, con.dtype, (ex.src[0].gep(tuple(idxs)),), new_ex_args)
 
 def no_vectorized_alu(alu):
   if alu.dtype.vcount == 1: return None
   alus = tuple(UOp(alu.op, alu.dtype.scalar(), tuple(s.gep(i) for s in alu.src), alu.arg) for i in range(alu.dtype.vcount))
   return UOp(Ops.VECTORIZE, alu.dtype, alus)
 
-def create_gate(root:UOp) ->
+def create_gate(root:UOp) -> UOp|None:
   @functools.lru_cache(None)
   def _gate_srcs(u:UOp, gate:UOp) -> UOp:
     if u.op is Ops.BARRIER: return u
@@ -407,22 +413,22 @@ def create_gate(root:UOp) -> Optional[UOp]:
 
 expander = PatternMatcher([
   # double expand
-  (UPat(Ops.
-   lambda outer, inner: UOp(Ops.
+  (UPat(Ops.UNROLL, name="outer", src=(UPat(Ops.UNROLL, name="inner"),)),
+   lambda outer, inner: UOp(Ops.UNROLL, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
   # do expansion
   (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.ASSIGN,
-        Ops.VECTORIZE, Ops.IF), name="root", custom_early_reject=set([Ops.
+        Ops.VECTORIZE, Ops.IF), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
   (UPat(Ops.CONTRACT, name="con"), do_contract),
   # vectorize DEFINE_ACC
   (UPat(Ops.VECTORIZE, src=UPat(Ops.DEFINE_ACC, name="acc"), name="v"), lambda acc,v: acc.replace(dtype=v.dtype)),
   # BARRIERs aren't actually expanded
-  (UPat(Ops.BARRIER, src=(UPat(Ops.
-   lambda ex: UOp(Ops.
-  # empty
-  (UPat(Ops.
-  #
-  (UPat(Ops.
-   lambda ex,x,y: UOp(Ops.
+  (UPat(Ops.BARRIER, src=(UPat(Ops.UNROLL, name="ex"),)),
+   lambda ex: UOp(Ops.UNROLL, dtypes.void, (UOp(Ops.BARRIER, dtypes.void, ex.src),)*len(ex.src), ex.arg)),
+  # empty UNROLL is NOOP
+  (UPat(Ops.UNROLL, src=(UPat.var('x'),), arg=()), lambda x: x),
+  # UNROLL GEP (needed for WMMA, generalize this) -> vectorized ALU
+  (UPat(Ops.UNROLL, name="ex", src=tuple(UPat.var('x').gep(i)+UPat.var('y').gep(i) for i in range(256 if AMX else 8))),
+   lambda ex,x,y: UOp(Ops.UNROLL, ex.dtype, tuple((x+y).gep(i) for i in range(256 if AMX else 8)), ex.arg)),
 ])
 
 def no_vectorized_load_store(ls:UOp):
@@ -446,8 +452,8 @@ devectorize = PatternMatcher([
   (UPat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store),
 ])
 
-def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:
-  if store_gate not in [gate.src[0] for gate in val.
+def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:UOp|None=None) -> UOp|None:
+  if store_gate not in [gate.src[0] for gate in val.toposort if gate.op is Ops.IF]: return None
   # remove the gate from the index
   return UOp.store(buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx), val)
 
@@ -468,7 +474,7 @@ migrate_indexing = PatternMatcher([
   (UPat(Ops.STORE, name="root"), create_gate),
 ])
 
-def move_mask(x:UOp, buf:UOp, idx:UOp, mask:UOp, cast:
+def move_mask(x:UOp, buf:UOp, idx:UOp, mask:UOp, cast:UOp|None=None) -> UOp:
   # this moves the mask from the indexing to the load/store op for rendering
   nidx = buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx)
   return UOp.load(nidx, x.const_like(0), mask, *x.src[1:], dtype=x.dtype) if x.op is Ops.LOAD else UOp.store(nidx, x.src[1], mask, *x.src[2:])
@@ -481,8 +487,11 @@ pm_render = PatternMatcher([
   (UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
   (UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
   # move masks of loads/stores
-  (UPat((Ops.LOAD, Ops.STORE), src=(UPat.any(masked_index:=UPat(Ops.INDEX, src=(UPat(
+  (UPat((Ops.LOAD, Ops.STORE), src=(UPat.any(masked_index:=UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"), UPat.var("mask"))),
                                              masked_index.cast(None).named("cast")),), allow_any_len=True, name="x"), move_mask),
+  # gate any stores that aren't gated with ifs
+  (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
+   lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
 ])
 
 # *** uop graph ***
@@ -498,8 +507,9 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
   # expand
   sink = graph_rewrite(sink, sym+expander)
 
-  # devectorize + load_store_indexing
-  sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)+load_store_indexing
+  # devectorize + load_store_indexing + mulacc_unrolled, mulacc_unrolled must be last because it can break loop_collapse
+  sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)+load_store_indexing+
+    mulacc_unrolled)
 
   # final rules for the renderer (without sym)
   sink = graph_rewrite(sink, symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)+pm_render+extra_matcher)
tinygrad/codegen/transcendental.py

@@ -1,5 +1,4 @@
 import math
-from typing import Tuple
 from tinygrad.dtype import dtypes, DType
 from tinygrad.helpers import polyN
 from tinygrad.ops import UOp
@@ -22,7 +21,7 @@ def shl(x:UOp, y:int) -> UOp: return x * (2**y)
 def rintk(d:UOp) -> UOp:
   """round d:float to int away from 0"""
   out_dtype = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype]
-  return (d + d
+  return (d + (d<0.0).where(d.const_like(-0.5), d.const_like(0.5))).cast(out_dtype)
 
 def pow2if(q:UOp, float_dtype:DType):
   """cast(2^q, float_dtype) where q is any integer in the range of [-126, 127]"""
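The new rintk body rounds half away from zero by adding +0.5 or -0.5 before the integer cast. A plain-Python reference (assuming, as in C-style casts, that float-to-int conversion truncates toward zero):

def rintk_ref(d: float) -> int:
    # add +0.5 for d >= 0 and -0.5 for d < 0, then truncate toward zero
    return int(d + (0.5 if d >= 0.0 else -0.5))

assert [rintk_ref(v) for v in (2.5, -2.5, 1.4, -1.4)] == [3, -3, 1, -1]
assert round(2.5) == 2  # Python's round() uses half-to-even, for contrast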
@@ -49,7 +48,7 @@ def ldexp2k(d:UOp, e:UOp) -> UOp:
   assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype in (dtypes.int16, dtypes.int32, dtypes.int64)
   return (d * pow2if(shr(e, 1), d.dtype)) * pow2if(e - shr(e, 1), d.dtype)
 
-def frexp(v:UOp) ->
+def frexp(v:UOp) -> tuple[UOp, UOp]:
   """frexp(v) -> (mantissa, exponent) assuming v != 0"""
   assert v.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
   # m1 = masks for mantissa, m2 = masks to normalize the mantissa.
@@ -63,7 +62,7 @@ def frexp(v:UOp) -> Tuple[UOp, UOp]:
   return mantissa, exp
 
 # *** reduction algorithms for sine ***
-def payne_hanek_reduction(d:UOp) ->
+def payne_hanek_reduction(d:UOp) -> tuple[UOp, UOp]:
   """
   Performs Payne-Hanek Reduction: computes the remainder of `d` modulo pi/2 for the values `d` where
     39800.0 <= d <= +Inf
@@ -110,9 +109,9 @@ def payne_hanek_reduction(d:UOp) -> Tuple[UOp, UOp]:
   r = (p.cast(intermediate_dtype) * (3.4061215800865545e-19)).cast(d.dtype)
 
   # if fraction >= 0.5, r -= pi/2, q += 1
-  return f
+  return (f<0.5).where(r, r - math.pi/2), (f<0.5).where(q, q + 1)
 
-def cody_waite_reduction(d:UOp) ->
+def cody_waite_reduction(d:UOp) -> tuple[UOp, UOp]:
   """
   Performs Cody-Waite Reduction: computes the reminder of `d` modulo pi/2 for the values `d` where
     0 <= abs(d) <= 39800.0
@@ -177,14 +176,14 @@ def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp:
   # mask +-inf/nan as zero
   x = _lazy_map_numbers(d, d.const_like(0.0), d.const_like(0.0), d.const_like(0.0), d)
   # x_sign = sign(x)
-  x_sign = x.ne(0).where(x
+  x_sign = x.ne(0).where((x<0).where(x.const_like(-1), x.const_like(1)), x.const_like(0))
   x_abs = x * x_sign
   r, q = (cody_waite_reduction if fast else payne_hanek_reduction)(x_abs)
   if fast: result = sin_poly_small(r, q)
   else:
     # Payne Hanek Reduction assumes abs(x) >= pi/4, so for smaller values, use cody_waite_reduction.
     r_small, q_small = cody_waite_reduction(x_abs)
-    result = x_abs
+    result = (x_abs<switch_over).where(sin_poly_small(r_small, q_small), sin_poly_large(r, q))
   # adjusts the sign for abs(x)
   result = result * x_sign
   # sin(Inf) = NaN, sin(-Inf) = NaN, sin(NaN) = NaN
@@ -210,9 +209,9 @@ def xexp2(d:UOp) -> UOp:
   u = ldexp2k(u, q) # u*2^q
   upper, lower = {dtypes.float64: (1024, -2000), dtypes.float32: (128, -150), dtypes.float16: (23, -22)}[d.dtype]
   # Replace x >= upper with +inf
-  u = d
-  # Replace x
-  u = d
+  u = (d >= upper).where(d.const_like(math.inf), u)
+  # Replace x < lower with zero.
+  u = (d<lower).where(d.const_like(0.0), u)
   # exp2(NaN) = NaN
   return d.ne(d).where(d.const_like(math.nan), u)
 
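The rewritten clamp lines saturate exp2 outside the representable range. A hypothetical plain-Python reference for the float32 bounds from the table above (128 and -150); not the tinygrad implementation:

import math

def exp2_f32_ref(x: float) -> float:
    if math.isnan(x): return math.nan  # exp2(NaN) = NaN
    if x >= 128: return math.inf       # overflows float32 -> +inf
    if x < -150: return 0.0            # below even float32 denormals -> 0
    return 2.0 ** x

assert exp2_f32_ref(128.0) == math.inf and exp2_f32_ref(-151.0) == 0.0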
@@ -225,7 +224,7 @@ def xlog2(d:UOp) -> UOp:
   # TODO: float16 denormal need float32 to achieve precision
   if d.dtype == dtypes.float16: return xlog2(d.cast(dtypes.float32)).cast(dtypes.float16)
   FLT_MIN = d.const_like(1e-6 if d.dtype == dtypes.float16 else 1e-4)
-  is_denormal = d
+  is_denormal = d<FLT_MIN
   a = is_denormal.where(d * (2 ** 64), d)
 
   e = ilogb2k(a * (1.0 / 0.75)).cast(a.dtype)
@@ -246,7 +245,7 @@ def xlog2(d:UOp) -> UOp:
   # log2(Inf) = Inf
   r = d.ne(math.inf).where(r, r.const_like(math.inf))
   # log2(x) = NaN for x < 0
-  r = d
+  r = (d<-0.0).where(r.const_like(math.nan), r)
   # log2(0) = -Inf, but we will compare using the value of y because 1e-200==0 is true.
   # log2_zero = the value of unmasked xlog2(0.0).
   log2_zero = {dtypes.float64: -1087, dtypes.float32: -191, dtypes.float16: -79}[d.dtype]