tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +1 -1
- tinygrad/apps/llm.py +206 -0
- tinygrad/codegen/__init__.py +116 -0
- tinygrad/codegen/devectorizer.py +315 -172
- tinygrad/codegen/expander.py +8 -16
- tinygrad/codegen/gpudims.py +89 -0
- tinygrad/codegen/linearize.py +205 -203
- tinygrad/codegen/lowerer.py +92 -139
- tinygrad/codegen/opt/__init__.py +38 -0
- tinygrad/codegen/opt/heuristic.py +125 -0
- tinygrad/codegen/opt/kernel.py +510 -0
- tinygrad/{engine → codegen/opt}/search.py +51 -35
- tinygrad/codegen/opt/swizzler.py +134 -0
- tinygrad/codegen/opt/tc.py +127 -0
- tinygrad/codegen/quantize.py +67 -0
- tinygrad/device.py +122 -132
- tinygrad/dtype.py +152 -35
- tinygrad/engine/jit.py +81 -54
- tinygrad/engine/memory.py +46 -27
- tinygrad/engine/realize.py +82 -41
- tinygrad/engine/schedule.py +70 -445
- tinygrad/frontend/__init__.py +0 -0
- tinygrad/frontend/onnx.py +1253 -0
- tinygrad/frontend/torch.py +5 -0
- tinygrad/gradient.py +19 -27
- tinygrad/helpers.py +95 -47
- tinygrad/nn/__init__.py +7 -8
- tinygrad/nn/optim.py +72 -41
- tinygrad/nn/state.py +37 -23
- tinygrad/renderer/__init__.py +40 -60
- tinygrad/renderer/cstyle.py +143 -128
- tinygrad/renderer/llvmir.py +113 -62
- tinygrad/renderer/ptx.py +50 -32
- tinygrad/renderer/wgsl.py +27 -23
- tinygrad/runtime/autogen/am/am.py +5861 -0
- tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
- tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
- tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
- tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
- tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
- tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
- tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
- tinygrad/runtime/autogen/comgr.py +35 -9
- tinygrad/runtime/autogen/comgr_3.py +906 -0
- tinygrad/runtime/autogen/cuda.py +2419 -494
- tinygrad/runtime/autogen/hsa.py +57 -16
- tinygrad/runtime/autogen/ib.py +7171 -0
- tinygrad/runtime/autogen/io_uring.py +917 -118
- tinygrad/runtime/autogen/kfd.py +748 -26
- tinygrad/runtime/autogen/libc.py +613 -218
- tinygrad/runtime/autogen/libusb.py +1643 -0
- tinygrad/runtime/autogen/nv/nv.py +8602 -0
- tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
- tinygrad/runtime/autogen/opencl.py +2 -4
- tinygrad/runtime/autogen/sqtt.py +1789 -0
- tinygrad/runtime/autogen/vfio.py +3 -3
- tinygrad/runtime/autogen/webgpu.py +273 -264
- tinygrad/runtime/graph/cuda.py +3 -3
- tinygrad/runtime/graph/hcq.py +68 -29
- tinygrad/runtime/graph/metal.py +29 -13
- tinygrad/runtime/graph/remote.py +114 -0
- tinygrad/runtime/ops_amd.py +537 -320
- tinygrad/runtime/ops_cpu.py +108 -7
- tinygrad/runtime/ops_cuda.py +12 -14
- tinygrad/runtime/ops_disk.py +13 -10
- tinygrad/runtime/ops_dsp.py +47 -40
- tinygrad/runtime/ops_gpu.py +13 -11
- tinygrad/runtime/ops_hip.py +6 -9
- tinygrad/runtime/ops_llvm.py +35 -15
- tinygrad/runtime/ops_metal.py +29 -19
- tinygrad/runtime/ops_npy.py +5 -3
- tinygrad/runtime/ops_null.py +28 -0
- tinygrad/runtime/ops_nv.py +306 -234
- tinygrad/runtime/ops_python.py +62 -52
- tinygrad/runtime/ops_qcom.py +28 -39
- tinygrad/runtime/ops_remote.py +482 -0
- tinygrad/runtime/ops_webgpu.py +28 -28
- tinygrad/runtime/support/am/amdev.py +114 -249
- tinygrad/runtime/support/am/ip.py +211 -172
- tinygrad/runtime/support/amd.py +138 -0
- tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
- tinygrad/runtime/support/compiler_cuda.py +8 -11
- tinygrad/runtime/support/elf.py +2 -1
- tinygrad/runtime/support/hcq.py +184 -97
- tinygrad/runtime/support/ib.py +172 -0
- tinygrad/runtime/support/llvm.py +3 -4
- tinygrad/runtime/support/memory.py +251 -0
- tinygrad/runtime/support/nv/__init__.py +0 -0
- tinygrad/runtime/support/nv/ip.py +581 -0
- tinygrad/runtime/support/nv/nvdev.py +183 -0
- tinygrad/runtime/support/system.py +170 -0
- tinygrad/runtime/support/usb.py +268 -0
- tinygrad/runtime/support/webgpu.py +18 -0
- tinygrad/schedule/__init__.py +0 -0
- tinygrad/schedule/grouper.py +119 -0
- tinygrad/schedule/kernelize.py +368 -0
- tinygrad/schedule/multi.py +231 -0
- tinygrad/shape/shapetracker.py +40 -46
- tinygrad/shape/view.py +88 -52
- tinygrad/tensor.py +968 -542
- tinygrad/uop/__init__.py +117 -0
- tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
- tinygrad/uop/mathtraits.py +169 -0
- tinygrad/uop/ops.py +1021 -0
- tinygrad/uop/spec.py +228 -0
- tinygrad/{codegen → uop}/symbolic.py +239 -216
- tinygrad/uop/upat.py +163 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
- tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
- tinygrad/viz/index.html +203 -403
- tinygrad/viz/js/index.js +718 -0
- tinygrad/viz/js/worker.js +29 -0
- tinygrad/viz/serve.py +224 -102
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
- tinygrad-0.11.0.dist-info/RECORD +141 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/kernel.py +0 -693
- tinygrad/engine/multi.py +0 -161
- tinygrad/ops.py +0 -1003
- tinygrad/runtime/ops_cloud.py +0 -220
- tinygrad/runtime/support/allocator.py +0 -94
- tinygrad/spec.py +0 -155
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
- tinygrad/viz/perfetto.html +0 -178
- tinygrad-0.10.2.dist-info/RECORD +0 -99
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,119 @@
|
|
1
|
+
from tinygrad.uop.ops import Ops, UOp, resolve, can_pad, GroupOp, UPat, PatternMatcher, graph_rewrite
|
2
|
+
from tinygrad.helpers import all_int, prod, unwrap, dedup, DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES, FUSE_CONV_BW
|
3
|
+
from tinygrad.shape.shapetracker import ShapeTracker
|
4
|
+
|
5
|
+
# Ops the grouper treats as already contiguous: they are skipped when realizing SINK sources,
# when realizing parents of COPY/MSELECT/MSTACK, and as VIEW bases in realize_before_view.
ALWAYS_CONTIGUOUS: set[Ops] = {
  Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
  Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK,
  Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD,
}
|
8
|
+
|
9
|
+
# **** Grouper: decides which UOps get realized
|
10
|
+
|
11
|
+
def realize(ctx:dict[UOp, None], tr:UOp) -> None:
  """Record `tr` in the ordered-set `ctx` of UOps to realize (never rewrites, so returns None)."""
  ctx.setdefault(tr, None)
|
12
|
+
|
13
|
+
def realize_parents(ctx:dict[UOp, None], rb:UOp) -> None:
  """Record every source of `rb` in `ctx` for realization, except sources whose op is in ALWAYS_CONTIGUOUS."""
  ctx.update((parent, None) for parent in rb.src if parent.op not in ALWAYS_CONTIGUOUS)
|
16
|
+
|
17
|
+
def realize_before_view(ctx:dict[UOp, None], view:UOp, tr:UOp) -> None:
  """
  Decide whether `tr`, the source of VIEW `view`, must be realized before being read through that view.

  `tr` is added to `ctx` when the view is masked and `tr` can't be padded safely. It is also added when the view
  expands `tr`, unless DONT_REALIZE_EXPAND is set. A single masked view whose valid region is no larger than `tr`
  is left alone (folded).
  """
  st = unwrap(view.st)
  views = st.views
  # masked view over an op that isn't safe to pad: read it from a realized buffer instead
  if any(v.mask is not None for v in views) and not can_pad(tr, ctx):
    realize(ctx, tr)
    return None
  # fold simple pads: one masked view whose valid region has no more elements than tr
  if len(views) == 1:
    mask = views[-1].mask
    if mask is not None and all_int(tr.shape) and resolve(prod(tr.shape) >= prod([hi-lo for lo,hi in mask])): return None
  # realize before expand
  if resolve(prod(tr.shape) < prod(st.shape)) and not DONT_REALIZE_EXPAND: realize(ctx, tr)
  return None
|
25
|
+
|
26
|
+
# Rules run with ctx=realizes from group_realizes. Every handler only records UOps to realize in ctx and returns None,
# so these rules never rewrite the graph.
# NOTE(review): since all handlers return None, presumably PatternMatcher keeps trying later rules for the same UOp
# (e.g. COPY hits both the `realize` rule and the `realize_parents` rule) — confirm against PatternMatcher.
do_realize = PatternMatcher([
  # always realize SINK parents
  (UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)),
  # always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
  (UPat({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}, name="tr"), realize),
  # realize before expand or unsafe pad ops
  (UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="tr"),), name="view"), realize_before_view),
  # realize parents of COPY, MSELECT, MSTACK
  (UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_parents),
])
|
36
|
+
|
37
|
+
def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:dict[UOp, dict[UOp, None]], realizes:dict[UOp, None],
                    reduce_for_op:dict[UOp, UOp], group:dict[UOp, None], cache:dict[tuple[UOp, ShapeTracker], None]) -> None:
  """
  Walk forward from `tr` through its consumers. Collect into `group` the realized UOps that the REDUCE_AXIS `r` reaches.

  Args:
    tr: current UOp of the walk (the first call passes `r` itself).
    st: ShapeTracker accumulated along the path from `r` to `tr`.
    r: the REDUCE_AXIS being grouped.
    children: map from a base UOp to its consumers.
    realizes: UOps that will be realized.
    reduce_for_op: realized UOps already paired with a reduce (filled by group_realizes).
    group: output ordered set of reached realized UOps. `r` itself is added as a marker when it can't be
      fused (non-contiguous path, size mismatch, another reduce downstream, or a UOp already paired).
    cache: visited (tr, st) pairs, so each path state is explored once.

  Returns None; the `return group.setdefault(...)` statements are early exits (setdefault returns the None value).
  """
  if (tr, st) in cache: return
  cache.setdefault((tr, st))
  rsize = unwrap(r.st).size
  if tr in realizes and tr is not r:
    # can only fuse contiguous
    # max one reduceop per kernel
    if not st.contiguous or st.size != rsize or tr in reduce_for_op: group.setdefault(r)
    return group.setdefault(tr)
  for tr_next in children.get(tr, {}):
    # max one reduceop per kernel
    if tr_next.op is Ops.REDUCE_AXIS: return group.setdefault(r)
    # can only fuse contiguous
    # (tr_next reading tr through more than one distinct view is not a single contiguous path)
    if len(st_childs:=dedup(unwrap(x.st) for x in tr_next.src if x.base == tr)) > 1: return group.setdefault(r)
    recursive_group(tr_next, st+st_childs[0], r, children, realizes, reduce_for_op, group, cache)
|
53
|
+
|
54
|
+
def group_realizes(sink:UOp) -> dict[UOp, None]:
  """
  Decide which UOps in the tensor graph rooted at `sink` must be realized into their own buffers.

  Starts from the UOps `do_realize` always realizes. Then each REDUCE_AXIS is paired with realized consumers it can
  be fused into. A reduce that can't be cleanly paired is realized itself, or it is chased down to a contiguous
  child that is realized instead.

  Returns:
    Ordered set (dict with None values) of the UOps to realize.
  """
  # start by adding uops that always realize
  realizes: dict[UOp, None] = {}
  sink = graph_rewrite(sink, do_realize, ctx=realizes, name="do_realize")
  if DONT_GROUP_REDUCES: return realizes

  # construct children graph (only for bases)
  children: dict[UOp, dict[UOp, None]] = {}
  assigns: dict[UOp, None] = {}
  for u in (toposort:=sink.toposort()):
    if u.op in {Ops.VIEW, Ops.SINK}: continue
    if u.op is Ops.ASSIGN: assigns[u.buf_uop] = None
    for s in u.src: children.setdefault(s.base, {})[u] = None

  # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
  reduce_for_op: dict[UOp, UOp] = {}
  double_reduces: list[UOp] = []
  for r in toposort:
    if r.op is not Ops.REDUCE_AXIS: continue
    # reduces already carrying the fuse marker (arg[2] is True) are not grouped here
    if len(r.arg) == 3 and r.arg[2] is True: continue
    # a reduce reading a VIEW of another reduce: candidate for fusing both reduces (FUSE_CONV_BW)
    if FUSE_CONV_BW and r.src[0].base.op is Ops.REDUCE_AXIS and r.src[0] is not r.src[0].base: double_reduces.append(r)
    if r in realizes: continue
    group: dict[UOp, None] = {}
    recursive_group(r, unwrap(r.st), r, children, realizes, reduce_for_op, group, cache={})
    # max one reduceop per kernel
    can_chase = all(tr not in reduce_for_op for tr in group)
    for u in r.toposort(gate=lambda u: u not in realizes):
      if u.op is Ops.REDUCE_AXIS and u.src[0].base.op is Ops.CONST:
        can_chase = False
        break
    # TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
    forced_realize = r in group
    # can only have one output
    if not forced_realize and len(group) > 1: forced_realize = True
    # can only fuse assign if no other assign_target is used in the kernel
    if not forced_realize and (assign_targets:={x.buf_uop for x in group if x.op is Ops.ASSIGN}):
      parents = [r, *group]
      while parents and not forced_realize:
        p = parents.pop().base
        if p.op is Ops.BUFFER and p in assigns and p not in assign_targets: forced_realize, can_chase = True, False
        if p in realizes: continue
        parents.extend(p.src)
    if forced_realize or not group:
      tr = r
      if can_chase:
        # can chase this down to contiguous children
        st = unwrap(tr.st)
        while len(lst:=children.get(tr, {})) == 1:
          tr_next = next(iter(lst))
          st_childs = dedup(unwrap(s.st) for s in tr_next.src if s.base is tr)
          if len(st_childs) > 1: break
          if st.size != st_childs[0].size: break
          st = st + st_childs[0]
          if not st.contiguous or tr_next.op is Ops.REDUCE_AXIS: break
          tr = tr_next
        # don't cast to higher size before store (tr cannot be realized if forced_realize)
        if tr.op is Ops.CAST and tr.dtype.itemsize > tr.src[0].dtype.itemsize:
          tr = tr.src[0].base
      group = {tr: None}
      realizes[tr] = None
    reduce_for_op.update((member, r) for member in group)
  # fuse double reduces with no other child
  for reduceop in double_reduces:
    top_reduce = reduceop.src[0].base
    # pop instead of del: the top reduce may never have been added to realizes (e.g. it was skipped above because it
    # already carried the fuse marker). Then there is nothing to un-realize, and `del` would raise KeyError.
    if len(children.get(top_reduce, {})) == 1: realizes.pop(top_reduce, None)
  return realizes
|
@@ -0,0 +1,368 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, identity_element, resolve
|
3
|
+
from tinygrad.uop.ops import track_rewrites, _substitute
|
4
|
+
from tinygrad.uop.spec import type_verify, tensor_uop_spec
|
5
|
+
from tinygrad.uop.symbolic import symbolic_simple
|
6
|
+
from tinygrad.helpers import Metadata, all_int, all_same, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP
|
7
|
+
from tinygrad.dtype import ImageDType
|
8
|
+
from tinygrad.schedule.multi import multi_pm
|
9
|
+
from tinygrad.schedule.grouper import group_realizes, ALWAYS_CONTIGUOUS
|
10
|
+
from tinygrad.codegen.opt.swizzler import merge_views, apply_swizzle, swizzle_reduceop
|
11
|
+
|
12
|
+
# kernel creation can recurse deeply, so raise the interpreter-wide recursion limit when this module is imported
import sys
sys.setrecursionlimit(10_000)
|
15
|
+
|
16
|
+
# **** schedule simplifier
|
17
|
+
|
18
|
+
def simplify_stride0_reduce(reduce:UOp, x:UOp):
  """
  Collapse a REDUCE_AXIS over axes that all have stride 0 in `x`'s (unmasked) last view.

  Shrinks the reduced axes to size 1. Then ADD becomes a multiply by the number of reduced elements, MUL becomes a
  power, and MAX passes through. Returns None when the pattern does not apply.
  """
  st = unwrap(x.st)
  # must be unmasked (NOTE: can be relaxed if not masked on stride 0 axis)
  if any(v.mask is not None for v in st.views): return None
  axes = reduce.arg[1]
  # must have all stride 0 in the relevant axis (NOTE: can do partial)
  if not all(st.views[-1].strides[ax] == 0 for ax in axes) or not all_int(x.shape): return None
  count = prod(x.shape[ax] for ax in axes)
  shrunk = x.shrink(tuple((0,1) if ax in axes else (0,dim) for ax,dim in enumerate(x.shape)))
  op = reduce.arg[0]
  if op == Ops.ADD: return shrunk*count
  if op == Ops.MUL: return shrunk.pow(count)
  if op == Ops.MAX: return shrunk  # NOTE: Ops.MAX is passthrough
  return None
|
29
|
+
|
30
|
+
def split_reduceop(reduce:UOp, x:UOp):
  """
  Split a large reduce into two reduces, so some of the reduction work becomes global (parallel) outputs.

  Only applies when SPLIT_REDUCEOP is on, `x` has integer shape, and the reduce ratio meets REDUCEOP_SPLIT_THRESHOLD.
  Picks the first reduced axis divisible by a split factor in [8, min(256, 2**REDUCEOP_SPLIT_SIZE // output size)]
  whose real stride is nonzero. The split part is moved to the last axis. Returns None if no split applies.
  """
  if not SPLIT_REDUCEOP or not all_int(x.shape): return None
  if (prod(x.shape)//prod(reduce.shape)) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768): return None
  # if there are few globals, make some reduces into globals by splitting into two kernels
  # cap output buffer to 2**22: heuristic number of global outputs to achieve max occupancy with enough locals+upcasts for gemm
  # ~2**10 should be enough if GROUP is used
  # 256 split maximum should be "negligible reduce" for low prod(reduce.shape), 8 split minimum.
  # split is moved to the end to provide maximum locality for the second phase reduce.
  real_strides = unwrap(x.st).real_strides(ignore_valid=True)
  max_split = min(256, 2**getenv("REDUCEOP_SPLIT_SIZE", 22)//prod(reduce.shape))
  candidates = ((ax, div) for ax in reduce.arg[1] for div in range(max_split, 7, -1)
                if x.shape[ax] % div == 0 and real_strides[ax] != 0)
  if (first := next(candidates, None)) is None: return None
  dim_to_split, divisor = first
  splitted_shape = (*x.shape[:dim_to_split], divisor, x.shape[dim_to_split]//divisor, *x.shape[dim_to_split+1:])
  new_order = [d for d in range(len(splitted_shape)) if d != dim_to_split] + [dim_to_split]
  splitted = x.reshape(splitted_shape).permute(tuple(new_order))
  if DEBUG >= 3: print(f"split {divisor}: {x.shape} -> {splitted.shape} -> {reduce.shape}")
  # reduce original axes, then split
  first_phase = splitted.r(*reduce.arg)
  return first_phase.r(reduce.arg[0], (len(reduce.shape),)).reshape(reduce.shape)
|
46
|
+
|
47
|
+
def copy_reorder_view(copy:UOp, view:UOp, base:UOp):
  """
  Reorder COPY(VIEW(base)).

  If the view has fewer elements than its base, the view is materialized first so only those elements are copied.
  Otherwise the base is copied and the view is re-applied on the destination device.
  """
  view_is_smaller = prod(view.shape) < prod(base.shape)
  return view.contiguous().copy_to_device(copy.device) if view_is_smaller else base.copy_to_device(copy.device).view(view.arg)
|
50
|
+
|
51
|
+
# Simplification rules applied to the tensor graph before grouping (extends symbolic_simple).
# Rule order matters: several rules match the same op (e.g. three REDUCE_AXIS rules, several COPY/CONTIGUOUS rules).
kernelize_sym = symbolic_simple+PatternMatcher([
  # UOp with size 0 is zero
  (UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 else None),
  # DETACH and CONTIGUOUS_BACKWARD are NOOPs here
  (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
  # reduce of size 0 is the identity element
  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
   lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
  # reduce on stride 0 is collapsed
  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), simplify_stride0_reduce),
  # split_reduceop
  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), split_reduceop),
  # COPY(CONST) creates a new CONST on the destination device
  (UPat(Ops.COPY, name="root", src=(UPat.cvar("x"), UPat(Ops.DEVICE))), lambda root,x: root.const_like(x.arg)),
  # non device changing COPY is a NOOP
  (UPat(Ops.COPY, name="c", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda c,x: x if c.device == x.device else None),
  # store a shrink before COPY, otherwise view after the COPY
  (UPat(Ops.COPY, src=(UPat(Ops.VIEW, src=(UPat.var("base"),), name="view"), UPat(Ops.DEVICE)), name="copy"), copy_reorder_view),
  # remove cast to image when it's already a contiguous image
  (UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"),)),)),
   lambda cast,base,vm: base.view(vm.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
  # CAST before masking constants
  (UPat.cvar("x").view().cast(name="c"), lambda x,c: x.cast(c.dtype).view(c.src[0].arg)),
  # make things that can't be images not images
  # (image dtype dropped to its base dtype when the element count differs or no unit-stride axis is a multiple of 4)
  (UPat(GroupOp.All-{Ops.BUFFER, Ops.VIEW, Ops.CONST, Ops.DEVICE}, name="u"), lambda u: u.replace(dtype=dt.base) if isinstance(dt:=u.dtype,ImageDType)
   and (prod(u.shape) != prod(dt.shape) or not any(u.shape[x]%4 == 0 for x in u.st.unit_stride_axes())) else None),
  # remove contiguous if we can just view the buffer
  (UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)),
   lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None),
  # contiguous/buffer/copy/assign is already contiguous
  (UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER, Ops.COPY, Ops.ASSIGN)),)), lambda root: root.src[0]),
  # substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
  (UPat((Ops.BITCAST, Ops.CONTIGUOUS), src=(UPat.var("x"),), name="t"), lambda x,t: UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,),
    (t.size, x.st.views[0].offset)).reshape(t.shape) if isinstance(x.device, str) and x.device.startswith("DISK") else None),
  # double ASSIGN to same target is one ASSIGN
  (UPat(Ops.ASSIGN, src=(UPat.var("t"), UPat(Ops.ASSIGN, src=(UPat.var("t"), UPat.var("x"))))), lambda x,t: t.assign(x.contiguous())),
  # ASSIGN to unrealized replaces the UOp
  # (kept as an ASSIGN when the target base is a BUFFER/BUFFER_VIEW, or an MSTACK made only of BUFFERs)
  (UPat(Ops.ASSIGN, src=(UPat.var("t"), UPat.var("x"))), lambda x,t: x.contiguous() if t.base.op not in {Ops.BUFFER, Ops.BUFFER_VIEW} and
   not (t.base.op is Ops.MSTACK and all(x.op is Ops.BUFFER for x in t.base.src)) else None),
  # put CAST to smaller dtype before EXPAND
  (UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm"),)), lambda cast,vm: vm.base.cast(cast.dtype).view(vm.st)
   if cast.dtype.itemsize <= vm.dtype.itemsize and resolve(prod(vm.shape) > vm.st.real_size()) else None),
  # put UnaryOps before EXPANDs, if it can fuse with the input
  (UPat(GroupOp.Unary, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="inp"),), name="v"),), name="alu"),
   lambda inp,v,alu: inp.alu(alu.op).view(v.st) if resolve(prod(alu.shape) > v.st.real_size()) else None),
])
|
97
|
+
|
98
|
+
# support for using a contiguous permuted view instead of the parent view if one exists
|
99
|
+
|
100
|
+
def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
  """
  On CONTIGUOUS(VIEW `src`), record in `ctx` that `src.base` can be read as `contig` through the inverse view.

  Nothing is recorded if the view is not invertible. Never rewrites (returns None).
  """
  inverse = unwrap(src.st).invert(src.base.shape)
  if inverse is None: return None
  ctx[src.base] = contig.view(inverse)
  return None
|
102
|
+
|
103
|
+
# The first rule fills ctx (base -> contiguous copy seen through the inverse view).
# The second rewrites ALU ops to read those replacements in place of the original base.
replace_contiguous = PatternMatcher([
  (UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, name="src"),), name="contig"), found_contiguous),
  (UPat(GroupOp.ALU, name="alu"),
   lambda ctx,alu: None if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) == alu.src else alu.replace(src=new_src)),
])
|
107
|
+
|
108
|
+
# **** create kernels
|
109
|
+
|
110
|
+
@dataclass(frozen=True)
class Kernel:
  """Immutable arg of an Ops.KERNEL UOp: the kernel's AST plus the metadata collected from the fused tensor ops."""
  # root of the kernel's UOp graph (a SINK once the kernel is created)
  ast: UOp
  # metadata of the ops fused into this kernel
  metadata: tuple[Metadata, ...] = ()

  def __repr__(self):
    if self.ast.op is Ops.SINK: ast_rep = f"SINK{tuple(s.op for s in self.ast.src)}"
    else: ast_rep = repr(self.ast.op)
    num_uops = len(list(self.ast.toposort()))
    return f"<Kernel {num_uops} {ast_rep} {self.metadata}>"
|
117
|
+
|
118
|
+
def create_kernel(x:UOp, b:UOp|None=None):
  """
  Wrap `x` in an Ops.KERNEL that writes into `b`, and return ASSIGN(buffer, KERNEL) reshaped to `x.shape`.

  If `b` is None, a new buffer with `x`'s device, size and dtype is allocated. If `b` covers only part of its base
  buffer, a BUFFER_VIEW (size, offset) of the base is the assign target.
  """
  if b is None: b = UOp.new_buffer(x.device, x.size, x.dtype)
  kernel_arg = Kernel(x.sink(), x.metadata or ())
  # the output buffer is always the first KERNEL source
  kernel = UOp(Ops.KERNEL, src=(b, *x.src), arg=kernel_arg)
  if b.size == b.base.size: target = b.base
  else: target = UOp(Ops.BUFFER_VIEW, b.dtype, (b.base,), (b.size, b.arg.views[0].offset))
  return target.assign(kernel).reshape(x.shape)
|
123
|
+
|
124
|
+
# KERNEL sources with these ops stay as kernel inputs; append_to_kernel inlines any other source's own sources instead
DONT_PLACE_IN_KERNEL = {
  Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER,
  Ops.MSELECT, Ops.MSTACK, Ops.MULTI, Ops.BIND,
}
|
125
|
+
def append_to_kernel(x:UOp):
  """
  Grow KERNEL `x` one step backwards.

  Each source not in DONT_PLACE_IN_KERNEL is replaced by its own sources, and its metadata is merged into the
  kernel's. Returns the updated KERNEL, or None when the (deduplicated) sources are unchanged.
  """
  new_srcs: list[UOp] = []
  metadata = x.arg.metadata
  for src in x.src:
    if src.op in DONT_PLACE_IN_KERNEL:
      new_srcs.append(src)
      continue
    new_srcs.extend(src.src)
    # NOTE: because const and device are shared UOps they don't change metadata
    if src.base.op in {Ops.CONST, Ops.DEVICE}: continue
    # NOTE: if it's a reshape after ASSIGN we're not fusing that parent kernel
    if src.op is Ops.RESHAPE and src.base.op is Ops.ASSIGN: continue
    if (src_meta:=src.metadata): metadata += src_meta
  deduped = tuple(dedup(new_srcs))
  if deduped == x.src: return None
  return x.replace(src=deduped, arg=Kernel(x.arg.ast, tuple(dedup(metadata))))
|
136
|
+
|
137
|
+
# Turn realized ASSIGN/CONTIGUOUS UOps into ASSIGN(buffer, KERNEL) via create_kernel. Then grow each KERNEL's
# sources backwards with append_to_kernel until realized inputs are reached.
create_kernels = PatternMatcher([
  # always give assign/contiguous a kernel
  (UPat.assign(UPat.var("b"), UPat(GroupOp.All-{Ops.KERNEL}), name="x"), create_kernel),
  (UPat(Ops.CONTIGUOUS, name="x"), create_kernel),
  # walk back the local graph until we reach a realized source
  (UPat(Ops.KERNEL, name="x"), append_to_kernel),
  # push RESHAPE through MSELECT
  (UPat(Ops.MSELECT, src=(UPat(Ops.RESHAPE, name="r"),), name="ms"), lambda ms,r: r.src[0].mselect(ms.arg).reshape(r.arg)),
  # push RESHAPE through MSTACK
  # NOTE(review): src is a bare UPat here, presumably meaning every source must be a RESHAPE;
  # the result reuses the first source's reshape arg for the stacked UOp.
  (UPat(Ops.MSTACK, src=UPat(Ops.RESHAPE), name="ms"),
   lambda ms: UOp(Ops.MSTACK, ms.dtype, tuple(x.src[0] for x in ms.src)).reshape(ms.src[0].arg)),
])
|
149
|
+
|
150
|
+
# **** fix kernel AST
|
151
|
+
|
152
|
+
def unbind_view(x:UOp):
  """If the ShapeTracker arg of VIEW `x` contains BIND variables, return `x` with the unbound ShapeTracker; else None."""
  has_bind = any(var.op is Ops.BIND for var in x.arg.vars())
  if not has_bind: return None
  unbound_st = x.arg.unbind()[0]
  return x.replace(arg=unbound_st)
|
155
|
+
|
156
|
+
# Rules fix_kernel_ast runs (bottom-up) on a kernel's AST. ctx is the kernel's list of buffers, output buffer first.
# BUFFERs become DEFINE_GLOBAL loads, SINK sources become stores, and CONTIGUOUS/MSELECT/BIND wrappers are stripped.
replace_buffers = PatternMatcher([
  # replace ASSIGN with the target BUFFER
  (UPat(Ops.ASSIGN, src=(UPat((Ops.BUFFER, Ops.LOAD)), UPat(Ops.KERNEL)), name="assign", allow_any_len=True), lambda assign: assign.src[0]),
  # HACK: select the 0 branch of MSTACK (the device is wrong after this, is that okay?)
  (UPat(Ops.MSTACK, name="x"), lambda x: x.src[0]),
  # LOAD
  # (the DEFINE_GLOBAL arg is the buffer's position in ctx)
  (UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)).load()),
  # no SINK for meta ops
  (UPat(Ops.SINK, src=(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Meta, name="x"),),))), lambda x:x),
  # STORE (except for meta ops)
  # (the i-th SINK source is stored into DEFINE_GLOBAL i, sized from ctx[i])
  (UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), lambda ctx,sink:
   UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(ctx[i].size), (), i).view(s.st), s) for i,x in enumerate(sink.src)],
            arg=sink.arg)),
  # remove CONTIGUOUS/MSELECT wrappers and DEVICE sources of VIEWs from kernel AST
  (UPat((Ops.CONTIGUOUS, Ops.MSELECT), src=(UPat.var("x"),)), lambda x: x),
  (UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
  # passthrough ASSIGN (but let MSTACK process first)
  (UPat(Ops.ASSIGN, src=(UPat(GroupOp.All-{Ops.MSTACK}), UPat()), name="x"), lambda x: x.src[1]),
  # remove any BINDs from VIEWS
  (UPat(Ops.VIEW, src=(UPat(), UPat((Ops.BIND, Ops.DEFINE_VAR))), allow_any_len=True, name="x"), lambda x: x.replace(src=x.src[0:1])),
  # remove any BINDs from DEFINE_VARs
  (UPat(Ops.BIND, name="x"), lambda x: x.src[0]),
  # remove BINDs from ShapeTrackers
  (UPat(Ops.VIEW, name="x"), unbind_view),
])
|
181
|
+
|
182
|
+
def fix_kernel_ast(k:UOp) -> UOp|None:
  """
  Build the codegen AST for KERNEL `k`: buffers become DEFINE_GLOBAL loads/stores (see replace_buffers).

  Returns None if the AST is a meta op or already consists only of STOREs. Otherwise returns `k` with the new AST.

  Raises:
    RuntimeError: if the kernel's buffer sources are on different devices.
  """
  if k.arg.ast.op in GroupOp.Meta or all(s.op is Ops.STORE for s in k.arg.ast.src): return None
  # BIND sources are variable bindings, not buffers. Exclude them everywhere a buffer is expected,
  # including the error message below: buf_uop is not defined for BIND.
  buf_srcs = [s for s in k.src if s.op is not Ops.BIND]
  # replace buffer with define_global + add load/store last
  bufs: list[UOp] = []
  for s in buf_srcs:
    s = s.buf_uop
    # traverse back through MSELECT and MSTACK. HACK: 0 branch of MSTACK only
    while s.op in {Ops.MSELECT, Ops.MSTACK}: s = s.src[0]
    bufs.append(s)
  # replace global memory ops with the BUFFER they write to
  # NOTE: merge_views is needed to unbind the reshapes
  ast = graph_rewrite(k.arg.ast, merge_views+replace_buffers, bufs, bottom_up=True, name="replace buffers")
  if ast.op is Ops.SINK and not all_same([x.device for x in buf_srcs]):
    raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in buf_srcs)}")
  return k.replace(arg=Kernel(ast, k.arg.metadata))
|
198
|
+
|
199
|
+
# Final pass over the kernel graph: build each KERNEL's codegen AST (fix_kernel_ast), and drop the sources of
# DEFINE_VAR UOps.
create_ast = PatternMatcher([
  (UPat(Ops.KERNEL, name="k"), fix_kernel_ast),
  (UPat(Ops.DEFINE_VAR, src=(UPat(),), allow_any_len=True, name="x"), lambda x: x.replace(src=())),
])
|
203
|
+
|
204
|
+
# **** add metadata of KERNEL outputs
|
205
|
+
|
206
|
+
def append_metadata(root:UOp, k:UOp):
  """
  Merge the metadata of ASSIGN `root` into its KERNEL source `k`.

  Returns the rewritten ASSIGN, or None when `root` has no metadata or the merge adds nothing new.
  """
  if not root.metadata: return None
  merged = tuple(dedup(k.arg.metadata+root.metadata))
  if merged == k.arg.metadata: return None
  new_kernel = k.replace(arg=Kernel(k.arg.ast, merged))
  return root.replace(src=(root.src[0], new_kernel, *root.src[2:]))
|
209
|
+
|
210
|
+
# on ASSIGN(target, KERNEL, ...), fold the ASSIGN's metadata into the KERNEL's
replace_metadata = PatternMatcher([
  (UPat(Ops.ASSIGN, src=(UPat(), UPat(Ops.KERNEL, name="k")), allow_any_len=True, name="root"), append_metadata),
])
|
211
|
+
|
212
|
+
# Local fusion rules run by do_fusion: push FUSE markers down through the graph, swizzling views onto sources.
pm_fuse = PatternMatcher([
  # FUSE on CONTIGUOUS removes FUSE
  (UPat(Ops.CONTIGUOUS, name="c").fuse(), lambda c: c),

  # FUSE triggers swizzle on reduceop
  (UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r").or_casted(),), name="view").fuse(),
   lambda r,src,view: ret.cast(view.dtype) if (ret:=swizzle_reduceop(r, src, view, fuse=True)) is not None else None),

  # FUSE on reduce (without view) adds fuse marker to grouper
  # (the marker is a trailing True in the REDUCE_AXIS arg; group_realizes and fuse_arange skip reduces with it)
  (UPat(Ops.REDUCE_AXIS, name="r").fuse(),
   lambda r: r.replace(src=(r.src[0].fuse(),), arg=r.arg+(True,)) if len(r.arg) == 2 else None),

  # remove FUSE and insert CONTIGUOUS if it's an unsafe pad
  (UPat(Ops.VIEW, src=(UPat(GroupOp.UnsafePad, name="alu"),), name="view").fuse(),
   lambda alu, view: alu.contiguous().view(view.st) if any(v.mask is not None for v in view.st.views) else None),

  # FUSE elementwise.
  (UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST}, name="alu"),), name="view").fuse(),
   lambda alu, view: alu.replace(src=tuple(apply_swizzle(x.view(view.arg)).fuse() for x in alu.src))),

  # push FUSE through to srcs
  (UPat(Ops.FUSE, name="x"), lambda x: x.src[0].replace(src=tuple(y.fuse() for y in x.src[0].src))),
])
|
235
|
+
|
236
|
+
def do_fusion(x:UOp):
  """
  Apply pm_fuse to the subgraph under FUSE `x`, without touching anything behind a CONTIGUOUS.

  Each reachable CONTIGUOUS is temporarily swapped for a placeholder (CONTIGUOUS of a source-less VIEW plus a unique
  UOp). Fusion stops there, and the placeholders are swapped back afterwards.
  """
  placeholders: dict[UOp, UOp] = {}
  def _stop_at_contiguous(u):
    if u.op is not Ops.CONTIGUOUS: return True
    placeholders[u] = u.replace(src=(UOp(Ops.VIEW, arg=u.st), UOp.unique()))
    return False
  x.toposort(gate=_stop_at_contiguous)
  fused = graph_rewrite(x.substitute(placeholders), pm_fuse, name="local fusion")
  restore = {placeholder: original for original, placeholder in placeholders.items()}
  return fused.substitute(restore)
|
244
|
+
|
245
|
+
def fuse_arange(root:UOp):
  """
  With FUSE_ARANGE set, insert FUSE markers after "arange-like" reduces (REDUCE_AXIS whose source base is a CONST)
  found under the REDUCE_AXIS `root`.

  For each such reduce `r` without the fuse marker, the walk starts at r's consumers. Wherever a visited UOp `u` feeds
  a consumer, that edge is wrapped in FUSE. The walk continues through a consumer only if no consumer of `u`
  depends on a REDUCE_AXIS or BUFFER other than `root` and `r`.

  Returns `root` with the rewritten consumers substituted, or None if there is nothing to fuse.
  """
  # skip if root is arange
  if not FUSE_ARANGE or root.src[0].base.op is Ops.CONST: return None
  # gather all local aranges (including any fused ones)
  local_arange: list[UOp] = []
  def gate_reduce(u):
    # record arange-like reduces; don't walk into the sources of ALWAYS_CONTIGUOUS ops or of other reduces (root is walked)
    if u.op is Ops.REDUCE_AXIS and u.src[0].base.op is Ops.CONST: local_arange.append(u)
    return u.op not in {*ALWAYS_CONTIGUOUS, Ops.REDUCE_AXIS} or u is root
  toposort = root.toposort(gate=gate_reduce)
  if not local_arange: return None
  # fuse the nearest expand child of arange
  local_children: dict[UOp, list[UOp]] = {}
  for u in toposort:
    for s in u.src: local_children.setdefault(s, []).append(u)
  fuse_rep: dict[UOp, UOp] = {}
  for r in local_arange:
    # skip if already fused
    if len(r.arg) > 2: continue
    q = list(local_children[r])
    while q:
      u = q.pop()
      if not (curr_children:=local_children.get(u, [])): continue
      for child in curr_children:
        other_paths = {s for s in child.toposort() if s.op in {Ops.REDUCE_AXIS, Ops.BUFFER} and s not in {root, r}}
        fuse_rep[child] = child.replace(src=tuple(s.fuse() if s is u else s for s in child.src))
        if other_paths: break
      # for/else: only keep walking when no child of u hit another reduce/buffer
      else: q.extend(curr_children)
  return root.substitute(fuse_rep, name="fuse_arange") if fuse_rep else None
|
273
|
+
|
274
|
+
# Fusion entry points used by get_kernelize_map: explicit FUSE markers (do_fusion) and arange fusion on reduces.
do_fuse = PatternMatcher([
  (UPat(Ops.FUSE, name="x"), do_fusion),
  (UPat(Ops.REDUCE_AXIS, name="root"), fuse_arange),
])
|
278
|
+
|
279
|
+
# Wrap every UOp in the realize map (ctx) in CONTIGUOUS. CONTIGUOUS and ASSIGN are excluded.
# The wrapped UOp is tagged (tag=1), and already-tagged UOps are skipped, presumably so a UOp is wrapped only once.
add_contiguous = PatternMatcher([
  (UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.ASSIGN}, name="x"),
   lambda ctx,x: None if x.tag is not None or x not in ctx else x.replace(tag=1).contiguous()),
])
|
281
|
+
|
282
|
+
# TODO: get this from the device through GrouperOpts
|
283
|
+
# maximum number of buffers a single kernel may bind, per device name (devices not listed have no limit)
DEVICE_MAX_BUFS = dict(METAL=32, WEBGPU=8)
|
284
|
+
|
285
|
+
def limit_bufs(root:UOp):
  """
  Keep a multi-input op from exceeding its device's per-kernel buffer limit.

  The limit comes from MAX_KERNEL_BUFFERS, or else DEVICE_MAX_BUFS for the device (0 means no limit). If the
  buffer-like UOps reachable from `root` reach the limit (minus one for the output), each source of `root` that is not
  itself one of those UOps is wrapped in a tagged CONTIGUOUS. This splits the work into more kernels.
  Returns the rewritten op, or None.
  """
  # check if backend has a buffer limit. Strip any ":N" device index for both single- and multi-device UOps,
  # so e.g. "METAL:1" maps to the "METAL" entry just like ("METAL:0", "METAL:1") does.
  first_device = root.device if isinstance(root.device, str) else root.device[0]
  device = first_device.split(":")[0]
  if not (MAX_BUFS:=getenv("MAX_KERNEL_BUFFERS", DEVICE_MAX_BUFS.get(device, 0))): return None
  # count number of unique buffers flowing into this op
  bufs: set[UOp] = set()
  def gate_input(u:UOp):
    # stop the walk at buffer-like UOps and record them
    if (is_load:=(u.op in {Ops.BUFFER, Ops.CONTIGUOUS, Ops.ASSIGN, Ops.MSTACK})): bufs.add(u)
    return not is_load
  root.toposort(gate=gate_input)
  # NOTE: this -1 is for the output buffer
  if len(bufs)>=MAX_BUFS-1:
    return root.replace(src=tuple(s if s.base in bufs else s.replace(tag=1).contiguous() for s in root.src))
  return None
|
298
|
+
|
299
|
+
def view_add_srcs(x:UOp):
  """If VIEW `x` has a single source and its ShapeTracker has variables, append those variables as extra sources."""
  view_vars = x.arg.vars()
  if len(x.src) != 1 or not len(view_vars): return None
  return x.replace(src=x.src+tuple(view_vars))
|
303
|
+
|
304
|
+
# Cleanup after add_contiguous: enforce device buffer limits, collapse nested CONTIGUOUS, simplify views, and
# attach ShapeTracker variables as VIEW sources.
finalize_contiguous = PatternMatcher([
  # if an op takes more than one input, check combined LOADs don't exceed device limits
  (UPat(set.union(GroupOp.Binary, GroupOp.Ternary), name="root"), limit_bufs),
  # merge contiguous
  (UPat(Ops.CONTIGUOUS, src=(UPat(Ops.CONTIGUOUS),), name="x"), lambda x: x.src[0]),
  # simplify views
  # NOTE: src must be a 1-tuple ("exactly one source"). A bare parenthesized UPat means "every source matches x",
  # which also accepts a VIEW with no sources and then calls the lambda without `x`.
  (UPat(Ops.VIEW, src=(UPat.var("x"),), name="v"), lambda x,v: x.view(new_st) if (new_st:=v.arg.simplify()) != v.arg else None),
  # vars to views srcs
  (UPat(Ops.VIEW, name="x"), view_add_srcs),
])
|
314
|
+
|
315
|
+
# clear the temporary tags (tag=1) set by add_contiguous and limit_bufs
remove_tags = PatternMatcher([
  (UPat(GroupOp.All, name="x"), lambda x: None if x.tag is None else x.replace(tag=None)),
])
|
316
|
+
|
317
|
+
@track_rewrites(name=lambda sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[sink].toposort() if u.op is Ops.KERNEL]))}", replay=True)
def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
  """
  Function to transform the Tensor UOp graph into a version with Ops.KERNEL

  Args:
    sink: The Ops.SINK rooting the Tensor graph.

  Returns:
    Map transforming each UOp in the sink to the Ops.KERNEL graph.

  Raises:
    RuntimeError: if an ASSIGN ordering cycle is detected between kernels, or (from fix_kernel_ast) a kernel's
      buffers are on different devices.
  """

  # multi + merge_views + simplify
  tensor_map = graph_rewrite_map(sink, multi_pm+do_fuse+merge_views+kernelize_sym+replace_contiguous, ctx={}, name="merge_views")

  # display the cleaned up tensor graph
  if getenv("VIZ"): graph_rewrite(tensor_map[sink], PatternMatcher([]), name="View Tensor Graph")

  # insert contiguous in places determined by the realize map
  realize_map = group_realizes(tensor_map[sink])
  tensor_map = graph_rewrite_map(tensor_map[sink], add_contiguous, ctx=realize_map, bottom_up=True, input_map=tensor_map, name="add_contiguous")
  tensor_map = graph_rewrite_map(tensor_map[sink], finalize_contiguous+remove_tags, input_map=tensor_map, name="finalize_contiguous")

  # group into kernels (this is context-free)
  tensor_map = graph_rewrite_map(tensor_map[sink], create_kernels, input_map=tensor_map, name="create_kernels")

  # if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
  # (`a`, the earlier ASSIGN to buffer `s` that u's kernel reads, gets `u` appended as an extra source)
  kernel_assign: dict[UOp, UOp] = {}
  assign_rep: dict[UOp, UOp] = {}
  for u in tensor_map[sink].toposort():
    if u.op is not Ops.ASSIGN: continue
    kernel_assign[u.buf_uop] = u
    for s in u.src[1].src:
      # TODO: this is probably broken for MSELECT/MSTACK
      if s.op is not Ops.BUFFER or s is u.buf_uop or (a:=kernel_assign.get(s)) is None: continue
      if any(x.op is Ops.ASSIGN and x.buf_uop is s for x in u.toposort()):
        raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on ASSIGN or BUFFER")
      assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
  if assign_rep:
    tensor_map = graph_rewrite_map(tensor_map[sink], _substitute, ctx=assign_rep, bottom_up=True, input_map=tensor_map, name="fix_assign")

  # finally, create the AST for kernels
  tensor_map = graph_rewrite_map(tensor_map[sink], create_ast+replace_metadata, bottom_up=True, input_map=tensor_map, name="create_ast")

  # display the final graph
  sched_sink = tensor_map[sink]
  if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Kernel Graph")

  # verify Kernels match the spec
  if __debug__: type_verify(list(sched_sink.toposort()), tensor_uop_spec)

  return tensor_map
|