tinygrad 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +6 -6
- tinygrad/codegen/kernel.py +230 -190
- tinygrad/codegen/linearizer.py +278 -384
- tinygrad/codegen/uops.py +415 -0
- tinygrad/device.py +132 -275
- tinygrad/dtype.py +53 -37
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +195 -0
- tinygrad/engine/realize.py +191 -0
- tinygrad/engine/schedule.py +362 -0
- tinygrad/engine/search.py +196 -0
- tinygrad/{mlops.py → function.py} +28 -14
- tinygrad/helpers.py +72 -43
- tinygrad/lazy.py +141 -240
- tinygrad/multi.py +169 -0
- tinygrad/nn/__init__.py +179 -8
- tinygrad/nn/datasets.py +7 -0
- tinygrad/nn/optim.py +106 -28
- tinygrad/nn/state.py +86 -17
- tinygrad/ops.py +70 -44
- tinygrad/renderer/__init__.py +61 -0
- tinygrad/renderer/assembly.py +276 -0
- tinygrad/renderer/cstyle.py +299 -206
- tinygrad/renderer/llvmir.py +118 -123
- tinygrad/runtime/autogen/amd_gpu.py +1900 -0
- tinygrad/runtime/autogen/comgr.py +865 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5761 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33328 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/hip_comgr.py +47 -0
- tinygrad/runtime/driver/hsa.py +143 -0
- tinygrad/runtime/graph/clang.py +38 -0
- tinygrad/runtime/graph/cuda.py +59 -54
- tinygrad/runtime/graph/hcq.py +143 -0
- tinygrad/runtime/graph/hsa.py +171 -0
- tinygrad/runtime/graph/metal.py +37 -41
- tinygrad/runtime/ops_amd.py +564 -0
- tinygrad/runtime/ops_clang.py +16 -14
- tinygrad/runtime/ops_cuda.py +130 -38
- tinygrad/runtime/ops_disk.py +45 -42
- tinygrad/runtime/ops_gpu.py +52 -50
- tinygrad/runtime/ops_hsa.py +278 -0
- tinygrad/runtime/ops_llvm.py +36 -56
- tinygrad/runtime/ops_metal.py +42 -24
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +630 -0
- tinygrad/runtime/ops_python.py +204 -0
- tinygrad/shape/shapetracker.py +41 -105
- tinygrad/shape/symbolic.py +98 -95
- tinygrad/shape/view.py +137 -35
- tinygrad/tensor.py +2367 -442
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/METADATA +19 -9
- tinygrad-0.9.0.dist-info/RECORD +60 -0
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
- tinygrad/features/image.py +0 -93
- tinygrad/features/multi.py +0 -103
- tinygrad/features/search.py +0 -160
- tinygrad/graph.py +0 -106
- tinygrad/jit.py +0 -152
- tinygrad/realize.py +0 -50
- tinygrad/runtime/graph/hip.py +0 -24
- tinygrad/runtime/ops_cpu.py +0 -45
- tinygrad/runtime/ops_hip.py +0 -97
- tinygrad/runtime/ops_torch.py +0 -49
- tinygrad-0.8.0.dist-info/RECORD +0 -41
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
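Note on the moves in the listing above: several top-level modules were relocated in 0.9.0 (jit.py, realize.py, graph.py and features/search.py into tinygrad/engine/, features/multi.py to multi.py, mlops.py renamed to function.py, and the numpy/torch CPU backends dropped). A hedged sketch of how import paths shift, inferred only from the rename list above; module paths only, no symbol names are assumed:

```python
# Import-path changes inferred from the 0.8.0 -> 0.9.0 file list above.
# Only module paths are shown; whether individual symbols kept their names
# is not established by this diff.
import tinygrad.engine.jit       # 0.8.0: tinygrad/jit.py
import tinygrad.engine.realize   # 0.8.0: tinygrad/realize.py
import tinygrad.engine.search    # 0.8.0: tinygrad/features/search.py
import tinygrad.engine.graph     # 0.8.0: tinygrad/graph.py
import tinygrad.multi            # 0.8.0: tinygrad/features/multi.py
import tinygrad.function         # 0.8.0: tinygrad/mlops.py
```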
tinygrad/codegen/kernel.py
CHANGED
@@ -1,52 +1,47 @@
 from __future__ import annotations
-import
+import math, itertools
 from typing import NamedTuple, Optional, List, Tuple, cast, Dict, Union
-from tinygrad.ops import LazyOp,
-from tinygrad.device import Device
+from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, UNSAFE_PAD_OPS
+from tinygrad.device import Device
+from tinygrad.renderer import Renderer, TensorCore
 from tinygrad.dtype import dtypes, ImageDType, DType
-from tinygrad.helpers import
-from tinygrad.shape.shapetracker import ShapeTracker
+from tinygrad.helpers import colored, ansilen, dedup, flatten, getenv, prod, DEBUG, round_up, all_int, get_contraction
+from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.symbolic import sint
 from tinygrad.shape.view import View, strides_for_shape
 from dataclasses import dataclass
 from enum import Enum, auto

 class OptOps(Enum):
-
+  TC = auto(); UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
   GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto() # noqa: E702
   def __lt__(self, x:OptOps): return self.value < x.value

+class KernelOptError(Exception): pass
+
+def check(cond:bool, msg:str=""):
+  if not cond: raise KernelOptError(msg)
+
 @dataclass(frozen=True, order=True)
 class Opt:
   op: OptOps
   axis: Optional[int] = None
   amt: Optional[int] = None
   def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, amt={self.amt})"
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-tensor_cores: Dict[str, List[TensorCore]] = {
-  "METAL": [
-    TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.float, dtype_out=dtypes.float, upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"), # noqa: E501
-    # TODO: enable half @ half -> half tensor core with correct dtypes in uop
-    # TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.half, upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"), # noqa: E501
-  ],
-  "HIP": [
-    TensorCore(device="HIP", dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, upcast_dim=1, threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501
-    TensorCore(device="HIP", dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.half, upcast_dim=1, threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501
-  ]
-}
+  def real_axis(self, k:Kernel):
+    if self.axis is None: return -1
+    if self.op is OptOps.UNROLL: return k.first_reduce+self.axis
+    if self.op in {OptOps.GROUP, OptOps.GROUPTOP}: return k.first_reduce+k.group_for_reduces+self.axis
+    return self.axis
+
+class TensorCoreOptions(NamedTuple):
+  bufs: Tuple[int, int] # the local aliased buffers for A and B
+  axes: List[int] # the location of the original N and M axes if still in the shape
+  axes_exist: List[bool] # true if the original N and M axes are still in the shape
+  def fix_axes(self, removed_axis:int): # adjust the TC axes if necesssary when an dimension is removed
+    for tc_dim in [i for i in range(2) if self.axes_exist[i]]:
+      if removed_axis < self.axes[tc_dim]: self.axes[tc_dim] -= 1
+      elif removed_axis == self.axes[tc_dim]: self.axes_exist[tc_dim] = False

 class LocalBuffer(NamedTuple):
   name: str
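The hunk above replaces assertion-based validation with a KernelOptError/check pair. A minimal standalone sketch of the pattern (reimplemented here rather than imported, so it runs on its own; names mirror the diff): a failed precondition raises KernelOptError, which optimization-search code treats as "this Opt doesn't apply here" rather than a hard failure.

```python
# Standalone sketch of the check()/KernelOptError pattern introduced above.
class KernelOptError(Exception): pass

def check(cond: bool, msg: str = ""):
  if not cond: raise KernelOptError(msg)

def try_apply(apply_fn) -> bool:
  # mirrors how hand-coded opts / BEAM search use it: invalid opts are simply skipped
  try:
    apply_fn()
    return True
  except KernelOptError:
    return False

print(try_apply(lambda: check(8 % 3 == 0, "no longer valid shift")))  # False
print(try_apply(lambda: check(8 % 4 == 0, "no longer valid shift")))  # True
```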
@@ -55,53 +50,46 @@ class LocalBuffer(NamedTuple):
   realized: None = None
   def __str__(self): return f"localbuffer<{self.name}[{self.size}]>"

-class LinearizerOptions(NamedTuple):
-  device: str = ""
-  # TODO: make this generic with a list of supported types
-  supports_float4: bool = True
-  supports_float4_alu: bool = True
-  has_local: bool = True
-  has_shared: bool = True
-  # NOTE: these two should be in z,y,x(reversed) order for cstyle backends, they are flipped when kernel is rendered
-  global_max: Optional[List[int]] = None
-  local_max: Optional[List[int]] = None
-
 class Kernel:
-  def __init__(self, ast:LazyOp, opts:Optional[
-    self.opts = opts
+  def __init__(self, *ast:LazyOp, opts:Optional[Renderer]=None):
+    self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
+    assert all(op.op is BufferOps.STORE for op in ast), f"kernels must have stores as the output, got {ast}"
+    assert len(set(op.arg.st.size for op in ast)) == 1, f"all outbufs should have the same size, got {[op.arg.st for op in ast]}"
     self.ast = ast
-
-
-    # fetch lazyop info
-    self.info: FlopCounter = get_lazyop_info(self.ast)
+    self.lazyops = flatten([op.lazyops for op in self.ast])

     # there's only allowed to be one reduceop
-
-
-
+    cached_ordered_lazyops: Dict[LazyOp, List[LazyOp]] = {}
+    def ordered_lazyops(op):
+      if op not in cached_ordered_lazyops: cached_ordered_lazyops[op] = dedup([item for x in op.src for item in ordered_lazyops(x)] + [op])
+      return cached_ordered_lazyops[op]
+    self.reduceops = dedup([x for out in self.ast for x in ordered_lazyops(out) if x.op in ReduceOps])
+    assert len(self.reduceops) < 2, "Only one reduceop allowed"

-    self.
-
+    self.outbufs, self.vars = [x.arg for x in self.ast], flatten([x.vars() for x in self.ast])
+    loadops = [BufferOps.LOAD, BufferOps.CONST]
+    self.bufs: List[Union[MemBuffer, ConstBuffer, LocalBuffer]] = self.outbufs + dedup([x.arg for x in self.lazyops if x.op in loadops])

     # get earlybufs, before the one reduce op
-    self.earlybufs = [x.arg for
+    self.earlybufs = [x.arg for reduceop in self.reduceops for x in reduceop.lazyops if x.op in BufferOps]
     self.full_buf_index: int = self.bufs.index(self.earlybufs[0]) if self.earlybufs else 0

     # create new shapetrackers inside this kernel, we will permute them
     self.sts: List[ShapeTracker] = [x.st for x in cast(List[Union[MemBuffer, ConstBuffer]], self.bufs)]

     # move all reduce axes to the end
-    reduce = list(enumerate(zip(self.full_shape, self.
+    reduce = list(enumerate(zip(self.full_shape, self.output_shape)))
     permute = tuple([i for i,(s,n) in reduce if s == n] + [i for i,(s,n) in reduce if s != n])
     self.reshape_and_permute(None, permute)

     # parameters for optimization
     self.applied_opts: List[Opt] = []
-    self.
+    self.group_for_reduces: int = 0
     self.upcasted: int = 0
     self.local_dims: int = 0
     self.local_alias: Dict[int, LocalBuffer] = {}
     self.tensor_core: Optional[TensorCore] = None
+    self.tensor_core_opts: Optional[TensorCoreOptions] = None
     self.dont_use_locals: bool = False

     # group simplifies
@@ -115,16 +103,17 @@ class Kernel:
     ret = type(self).__new__(type(self))

     # base linearizer params
-    ret.opts, ret.ast = self.opts, self.ast
+    ret.opts, ret.ast, ret.lazyops = self.opts, self.ast, self.lazyops

     # things downstream of the AST
-
-
-
+    ret.reduceops, ret.outbufs, ret.vars, ret.bufs, ret.earlybufs, ret.full_buf_index = \
+      self.reduceops, self.outbufs, self.vars, [x for x in self.bufs if not isinstance(x, LocalBuffer)], self.earlybufs, self.full_buf_index
+    ret.sts = self.sts[:len(ret.bufs)] # NOTE: must redo the local buffers with TC in beam

     # parameters for optimizations
-    ret.applied_opts, ret.
-    self.applied_opts[:], self.
+    ret.applied_opts, ret.group_for_reduces, ret.upcasted, ret.local_dims, ret.dont_use_locals = \
+      self.applied_opts[:], self.group_for_reduces, self.upcasted, self.local_dims, self.dont_use_locals
+    ret.tensor_core, ret.tensor_core_opts, ret.local_alias = self.tensor_core, self.tensor_core_opts, {}

     # uncached since linearize didn't run
     ret.applied_opts_cache = None
@@ -138,9 +127,10 @@ class Kernel:
   def shape_offsets(self, i:int): return itertools.product(*[list(range(cast(int, s))) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()] # noqa: E501
   def float4_axis(self, i:int): return [x-(self.shape_len-self.upcasted) for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x]%4 == 0] # noqa: E501

-  def upcasted_axis(self, i:int):
-
-
+  def upcasted_axis(self, i:int) -> List[Tuple[int, Optional[sint], bool]]:
+    upcasted_shape, upcasted_stride = self.sts[i].shape[self.shape_len-self.upcasted:], self.sts[i].real_strides()[self.shape_len-self.upcasted:]
+    assert all_int(upcasted_shape), f"cannot upcast a symbolic amount {upcasted_shape=}"
+    return list(zip(upcasted_shape, upcasted_stride,
                     [x!=y for x,y in zip(self.sts[0].shape[self.shape_len-self.upcasted:], self.full_shape[self.shape_len-self.upcasted:])]))

   # TODO: is there a better way to write this?
@@ -158,6 +148,9 @@ class Kernel:
   def first_reduce(self) -> int:
     return [x!=y for x,y in zip(self.sts[0].shape[:self.shape_len-self.upcasted]+(0,), self.full_shape[:self.shape_len-self.upcasted]+(1,))].index(True) # noqa: E501

+  @property
+  def reduceop(self) -> Optional[LazyOp]: return self.reduceops[0] if len(self.reduceops) > 0 else None
+
   @property
   def output_shape(self) -> Tuple[sint, ...]: return self.sts[0].shape

@@ -172,7 +165,7 @@ class Kernel:

   @property
   def upcast_in_mid_reduce_axes(self) -> List[int]:
-    return [j for j in range(self.first_reduce, self.first_reduce+
+    return [j for j in range(self.first_reduce, self.first_reduce+self.group_for_reduces) if self.full_shape[j] == self.sts[0].shape[j]]

   @property
   def global_dims(self) -> int: return self.first_reduce-self.local_dims
@@ -192,10 +185,10 @@ class Kernel:
     colors = ["blue"] * self.global_dims if not self.dont_use_locals else ["BLUE"] * self.global_dims
     # after global are local_dims; warp ones used in tensor cores must be closest to first_reduce (cyan)
     colors += ["cyan"] * self.local_dims
-    # between first_reduce and first_reduce +
-    colors += ["white" if i in self.upcast_in_mid_reduce_axes else "green" for i in range(self.first_reduce, self.first_reduce +
-    # between first_reduce +
-    colors += ["red"] * ((self.shape_len-self.upcasted) - (self.first_reduce +
+    # between first_reduce and first_reduce + group_for_reduces, they are either upcast mid reduce (white), or late upcasted (green)
+    colors += ["white" if i in self.upcast_in_mid_reduce_axes else "green" for i in range(self.first_reduce, self.first_reduce + self.group_for_reduces)] # noqa: E501
+    # between first_reduce + group_for_reduces and upcasted, they are reduce (red)
+    colors += ["red"] * ((self.shape_len-self.upcasted) - (self.first_reduce + self.group_for_reduces))
     # upcasted dimensions are reduce (magenta) or normal (yellow)
     colors += ["magenta" if self.full_shape[i] != self.sts[0].shape[i] else "yellow" for i in range(self.shape_len-self.upcasted, self.shape_len)]
     assert len(colors) == self.shape_len, "colors size mismatch"
@@ -219,7 +212,7 @@ class Kernel:

   # drops the final dimension
   def upcast(self):
-
+    check(self.full_shape[-1] != 1, "can't upcast a dimension with size 1")
     self.upcasted += 1

   # axis : the axis to pull from
@@ -242,7 +235,7 @@ class Kernel:
     if self.shape_len == 0: return False
     all_ones = [s==1 for s in self.full_shape]
     self.local_dims -= sum(all_ones[self.first_reduce-self.local_dims:self.first_reduce])
-    self.upcasted -= sum(all_ones[self.shape_len-self.upcasted:])
+    self.upcasted -= sum(all_ones[self.shape_len-self.upcasted:]) # TODO: no necessary since upcasted axis can't be un-upcasted
     self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None)
     return any(all_ones)

@@ -254,7 +247,7 @@ class Kernel:
     if isinstance(self.bufs[0].dtype, ImageDType):
       base_shape = self.bufs[0].dtype.shape
       if shape_idx_groups := get_contraction(self.output_shape, base_shape):
-        special_strides: Tuple[
+        special_strides: Tuple[sint, ...] = tuple()
         for i,g in enumerate(shape_idx_groups):
           shape_piece = tuple(self.output_shape[x] for x in g)
           assert prod(shape_piece) == base_shape[i], f"get_contraction was wrong? {shape_piece} != {base_shape[i]}"
@@ -281,35 +274,32 @@ class Kernel:
     # do the reshapes
     for i,x in enumerate(rets[:len(self.sts)]): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x]))

-  # ********************
+  # ******************** helpers ********************

-  def _limit_size(self, x: Tuple[int], max_size: List) -> Tuple[int, ...]:
-    new_shape
-    for i in range(
-      next_idx = (i + 1) %
+  def _limit_size(self, x: Tuple[int], max_size: List[Union[int,float]]) -> Tuple[int, ...]:
+    new_shape = list(x)
+    for i in range(len(new_shape)):
+      next_idx = (i + 1) % len(new_shape)
       while new_shape[i] > max_size[i]:
+        # TODO: what if new_shape[i] is not a multiple of 2??
        new_shape[i] = new_shape[i] // 2
-        if
-
-        else:
-          next_idx = (next_idx + 1) % dims
-          new_shape[next_idx] = new_shape[next_idx] * 2
+        next_idx = next_idx if new_shape[next_idx] <= max_size[next_idx] else (next_idx + 1) % len(new_shape)
+        new_shape[next_idx] = new_shape[next_idx] * 2
     return tuple(new_shape)

   def limit_dims_to_max(self, global_max: List[int], local_max: List[int]):
     # Check the global allocation limit, current the global_size will be flipped during codegen
     # and then padded right with 1s if its length < 3 which makes this part a bit awkward to write
-    global_dims
-    if global_dims > 0:
+    if self.global_dims > 0:
       if global_max:
-        tmp = global_max[:global_dims] + (local_max[:self.local_dims] if local_max else [])
-        if max(global_max) < max(self.full_shape[:global_dims]):
+        tmp = global_max[:self.global_dims] + (local_max[:self.local_dims] if local_max else [])
+        if max(global_max) < max(self.full_shape[:self.global_dims]):
          self.reshape_and_permute(lambda x: self._limit_size(x, tmp + [math.inf] * (len(self.full_shape)-len(tmp))), None)
-        assert max(global_max) >= max(self.full_shape[:global_dims]), f"device max allocation {max(self.full_shape[:global_dims])} exceeds global dim maximum {max(global_max)}" # noqa: E501
-      for i in range(global_dims-1):
+        assert max(global_max) >= max(self.full_shape[:self.global_dims]), f"device max allocation {max(self.full_shape[:self.global_dims])} exceeds global dim maximum {max(global_max)}" # noqa: E501
+      for i in range(self.global_dims-1):
         if i < len(global_max) and self.full_shape[i] > global_max[i]:
           order = list(range(len(self.full_shape)))
-          order[i], order[global_dims-1] = order[global_dims-1], order[i]
+          order[i], order[self.global_dims-1] = order[self.global_dims-1], order[i]
           self.reshape_and_permute(None, order)
           if DEBUG >= 3: print("permuted global dim", order, "due to allocation exceeds global limit")

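The reworked _limit_size above halves any dimension that exceeds its max_size and pushes the factor of 2 into a neighboring dimension that still has room (the TODO in the diff notes it assumes shapes that divide cleanly by 2). A standalone sketch with a worked example, using a free function in place of the method:

```python
import math

# Free-function version of the _limit_size logic shown above (illustrative only).
def limit_size(shape, max_size):
  new_shape = list(shape)
  for i in range(len(new_shape)):
    next_idx = (i + 1) % len(new_shape)
    while new_shape[i] > max_size[i]:
      new_shape[i] //= 2  # shrink the offending dim...
      # ...and double a neighbor that is still under its own limit
      next_idx = next_idx if new_shape[next_idx] <= max_size[next_idx] else (next_idx + 1) % len(new_shape)
      new_shape[next_idx] *= 2
  return tuple(new_shape)

print(limit_size((4096, 16, 4), [1024, 1024, math.inf]))  # (1024, 64, 4)
```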
@@ -332,134 +322,182 @@ class Kernel:

   # ******************** high level optimizers ********************

-  def
-  if use_tensor_cores and self.opts.has_local and self.reduceop and self.reduceop.op
-    for tc in
-      if not (use_tensor_cores==2 or (tc.arch is None or tc.arch == os.uname().machine)): continue
+  def _apply_tc_opt(self, use_tensor_cores:int, axis:int, opt_level:int) -> bool:
+    if use_tensor_cores and self.opts.has_local and self.reduceop is not None and self.reduceop.op is ReduceOps.SUM:
+      for tc in self.opts.tensor_cores:
         has_cast = tc.dtype_in != tc.dtype_out
+        if has_cast and not(self.reduceop.src[0].op is UnaryOps.CAST and self.reduceop.src[0].arg == tc.dtype_out): continue

-        if has_cast and not(self.reduceop.src[0].op == UnaryOps.CAST and self.reduceop.src[0].arg[0] == tc.dtype_out): continue
         mul_op = self.reduceop.src[0].src[0] if has_cast else self.reduceop.src[0]
+        if mul_op.op is not BinaryOps.MUL: continue

-
-
-
-
-
-
-
+        def buf_index(src: LazyOp) -> Optional[int]:
+          # TODO: apply tc even if the sources are not from LOAD
+          if src.op is BufferOps.LOAD and src.arg.dtype == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.arg))
+          try:
+            if opt_level >= 1 and src.op is UnaryOps.CAST and src.arg == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
+          except ValueError: return None
+          return None
+        if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: continue

-
+        buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
+        axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0]
+        axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0]
+        if not(axis_buf0 and axis_buf1 and ((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1))): continue

-
+        axis_choices = list(itertools.product(axis_buf0, axis_buf1, range(self.first_reduce, self.shape_len)))
+        if not(axis < len(axis_choices)): continue

-        s0, s1 =
-
-
-        def fix(needed, ax):
-          nonlocal s0, s1, s0_exists, s1_exists
-          if not needed: return
-          if s0_exists and ax == s0:
-            if s1_exists and s0 < s1: s1 -= 1
-            s0_exists = False
-          elif s1_exists and ax == s1:
-            if s0_exists and s1 < s0: s0 -= 1
-            s1_exists = False
+        s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2] # s0 is n, s1 is m, s2 is k
+        axis_pads = [(x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if self.full_shape[x]%tc.dims[i] != 0]
+        if axis_pads and (opt_level < 2): continue

         # tensor core -- unroll the reduce dim, upcast input, then create the correct thread pattern
-        self.
-
+        self.tensor_core_opts = (tc_opts:=TensorCoreOptions(bufs=(buf0, buf1), axes=[s0, s1], axes_exist=[True, True]))
+
+        # attempt to pad the tensor axes that require it
+        try:
+          for axis, dim in axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
+        except KernelOptError: continue
+        self.apply_opt(Opt(OptOps.UNROLL, s2-self.first_reduce, tc.dims[2]), append_opt=False)
+        for i, sz in enumerate([prod(x) for x in [[x[1] for x in tc.threads if x[0]==dim] for dim in range(2)]]): # upcast non-local'd N, M
+          if tc.dims[i] > sz: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[i], tc.dims[i]//sz), append_opt=False)
         for (tc_dim, tc_amt) in tc.threads:
-
+          self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[tc_dim], tc_amt), append_opt=False)

-        # assert tensor core
+        # assert tensor core
+        if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
         if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA
+        return True
+    return False

+  def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None, axis:int=0, tc_opt:int=getenv("TC_OPT")) -> bool:
+    """ Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true, otherwise return false.
+    Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N).
+
+    Keyword arguments:
+    use_tensor_cores -- controls how tensor cores are applied (default 1)
+      0: will disable any tensor core matching
+      1: enable tensor cores
+      2: apply tensor core shape but don't use UOp.WMMA
+    extra_opts -- additional Opt's to apply after the tensor core instead of the hand-coded additional Opt's (default None)
+    tc_opt -- controls which kinds of kernels may be eligible for tensor cores application (default 2 during BEAM, 0 otherwise)
+      0: applies to only kernels with a single reduce axis and direct BufferOps.LOAD into BinaryOps.MUL
+      1: allows kernels with multiple reduce axes and also multiplication of UnaryOps.CAST'd buffers
+      2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by applying padding those axes as needed
+    """
+    if not self.opts.tensor_cores and use_tensor_cores != 2: return False
+    try: # check TC first and apply hand-coded opts if successful
+      self.apply_opt(Opt(OptOps.TC, axis, tc_opt))
+
+      if (tc_opts:=self.tensor_core_opts) is not None:
         if extra_opts is not None:
           for opt in extra_opts: self.apply_opt(opt)
         else:
           # hand-coded TC opts
-
-
-
-
-
-
-
+          def late_upcast_tc(tc_dim: int):
+            if tc_opts.axes_exist[tc_dim]:
+              ax_div = [upc for upc in [5,4,3,2,1] if self.full_shape[tc_opts.axes[tc_dim]]%upc == 0][0]
+              if ax_div != 1: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[tc_dim], ax_div))
+          late_upcast_tc(1) # attempt to upcast M
+          late_upcast_tc(0) # attempt to upcast N
+
+          if self.tensor_core and tc_opts.axes_exist[0]: # attempt to local N
            for upc in [4,2]:
-              if self.full_shape[
-                self.apply_opt(Opt(OptOps.
+              if self.full_shape[tc_opts.axes[0]] % upc == 0:
+                self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[0], upc))
                break

-
-
-
-
-
-
+      return True
+    except KernelOptError:
+      return False
+
+  def apply_opt(self, opt:Opt, append_opt:bool=True):
+    check(not self.dont_use_locals or opt.op not in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals")
+
+    if opt.op is OptOps.TC:
+      check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine
+      check(opt.axis is not None and opt.amt is not None, "tensor core opts must have an axis and amt")
+      check((use_tensor_cores:=getenv("TC", 1)) == 2 or len(self.opts.tensor_cores) > 0, "must have tensor cores or TC=2")
+      check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), cast(int, opt.amt)), "no tensor core available")
+      self.applied_opts.append(opt)
+      return
+
+    axis = opt.real_axis(self)
+    check(axis < len(self.full_shape), "invalid axis")

-  def apply_opt(self, opt:Opt):
-    assert not self.dont_use_locals or opt.op not in {OptOps.LOCAL, OptOps.LASTLOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals" # noqa: E501
-    self.applied_opts.append(opt)
-    if opt.axis is not None:
-      axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+len(self.group_for_reduce) if opt.op in [OptOps.GROUP, OptOps.GROUPTOP] else 0)) # noqa: E501
-    else:
-      axis = -1
     if opt.amt is not None:
       amt = opt.amt if opt.amt != 0 else self.full_shape[axis]
-
-      if opt.op
-    else:
-
-    if opt.op in
-
-
-
-
-
-
-
+      check(isinstance(amt, int) and amt != 1, "shift/padto of amt 1 or Node is meaningless")
+      if opt.op is not OptOps.PADTO: check(self.full_shape[axis] % amt == 0, "no longer valid shift")
+    else: amt = -1
+
+    if self.reduceop and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or (self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})):
+      acc_sz, upcast_idx = dt.base.itemsize if isinstance((dt:=self.reduceop.dtype), ImageDType) else dt.itemsize, self.shape_len-self.upcasted
+      upcast_sz = prod([a for a,b in zip(self.full_shape[upcast_idx:], self.sts[0].shape[upcast_idx:]) if a == b])
+      local_sz = prod(self.full_shape[self.first_reduce-self.local_dims:self.first_reduce+self.group_for_reduces])
+      smem_sz = amt*acc_sz*upcast_sz*local_sz
+      check(smem_sz <= self.opts.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.opts.shared_max}")
+
+    if opt.op is OptOps.LOCAL: # cyan
+      check(self.opts.has_local, "target does not support local")
+      check(axis < self.global_dims, "local is for globals")
+      self.shift_to(axis, amt, insert_before=self.first_reduce)
       self.local_dims += 1
-    elif opt.op in
-
-
-
-      self.shift_to(axis, amt, top=(opt.op
-      self.
-    elif opt.op
-
-
+    elif opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: # green
+      check(self.opts.has_local and self.opts.has_shared, "target does not support local or shared mem")
+      check(axis >= self.first_reduce + self.group_for_reduces and axis < self.shape_len-self.upcasted, "must be reduce axis to group")
+      check(not self.tensor_core, "can't group with tensor cores")
+      self.shift_to(axis, amt, top=(opt.op is OptOps.GROUPTOP), insert_before=self.first_reduce + self.group_for_reduces)
+      self.group_for_reduces += 1
+    elif opt.op is OptOps.UNROLL: # purple
+      check(axis < self.shape_len-self.upcasted, "can't upcasted already upcasted")
+      check(amt <= 32, "don't unroll more than 32")
+      # TODO: fix upcast_count to put purples before yellows. broken because of METAL tensor cores
+      #upcast_count = sum(x == y for x,y in zip(self.full_shape[-self.upcasted:], self.output_shape[-self.upcasted:])) if self.upcasted else 0
+      #self.shift_to(axis, amt, insert_before=None if upcast_count == 0 else self.shape_len-upcast_count)
+      if self.full_shape[axis] == amt and axis == self.first_reduce: self.local_dims += 1 # first_reduce will ++, so offset loss in simplify_ones
+      if self.full_shape[axis] == amt and axis < self.first_reduce+self.group_for_reduces: self.group_for_reduces -= 1 # fully unrolling a GROUP
       self.shift_to(axis, amt, insert_before=None)
       self.upcast()
-    elif opt.op
-
-
+    elif opt.op is OptOps.UPCAST: # yellow
+      check(axis < self.first_reduce, "upcast is for non-reduce")
+      check(not(self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.threads)), "can't upcast TC locals")
+      check(amt <= 8, "don't upcast more than 8")
       self.shift_to(axis, amt, insert_before=None)
       self.upcast()
-    elif opt.op
-
+    elif opt.op is OptOps.UPCASTMID: # white
+      check(self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces != 0 and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce") # noqa: E501
       axes = self.sts[0].unit_stride_axes()
-
-
-
-      self.shift_to(axis, amt, insert_before=self.first_reduce +
-      self.
-    elif opt.op
-
-
+      check(len(axes) == 1, f"wrong number of stride 1 axis : {axes}")
+      check(axes[0] == axis, "wrong axis")
+      check(amt == 4, "don't upcast mid anything but 4")
+      self.shift_to(axis, amt, insert_before=self.first_reduce + self.group_for_reduces)
+      self.group_for_reduces += 1
+    elif opt.op is OptOps.NOLOCALS:
+      check(self.opts.has_local and not self.dont_use_locals, "NOLOCALS is meaningless if target does not support local or already not using locals")
+      check(self.local_dims == 0 and self.group_for_reduces == 0, "can't have no locals with locals")
       self.dont_use_locals = True
-    elif opt.op
-
-
+    elif opt.op is OptOps.PADTO:
+      check(not self.vars, "does not work with symbolic shape")
+      check(axis < self.shape_len - self.upcasted, "cannot pad upcasted")
+      # ok to pad SUM if all parent ops have f(0) = 0
+      if self.first_reduce <= axis:
+        check((r:=cast(LazyOp, self.reduceop)).op is ReduceOps.SUM and \
+              all(op.op not in UNSAFE_PAD_OPS for ops in r.src for op in ops.lazyops), "cannot pad")
       padded = False
       for i,st in enumerate(self.sts):
-
-
+        if self.sts[i].shape[axis] == 1: continue # reduced
+        check(self.sts[i].shape[axis] > amt//4, f"pad adds more than quadruple the work {self.sts[i].shape[axis]=} > {amt//4=}")
+        if (ru := round_up(cast(int, self.sts[i].shape[axis]), cast(int, amt)) - self.sts[i].shape[axis]):
           # pad right seems to be faster
           self.sts[i] = st.pad(((0,0),) * axis + ((0,ru),) + ((0,0),) * (len(st.shape)-axis-1))
           padded = True
-
-
+      check(padded, "nothing was padded")
+
+    if append_opt: self.applied_opts.append(opt)
+    if self.simplify_ones() and self.tensor_core_opts:
+      self.tensor_core_opts.fix_axes(axis) # fix up axes in TC opts if required after simplify_ones()

   def required_optimizations(self):
     if self.bufs[0].dtype.__class__ is ImageDType:
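For context on the new API in the hunk above: tensor-core selection is now itself an Opt (OptOps.TC) validated by check(), and apply_tensor_cores() wraps it in a try/except KernelOptError. A hedged usage sketch follows; the Kernel construction is elided since a real call site builds it from a scheduled AST, so treat this as illustrative rather than the library's documented workflow.

```python
from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError

def tune(k: Kernel) -> Kernel:
  # tc_opt=2 also admits M/N/K axes that need padding; use_tensor_cores=2 applies the
  # tensor-core shape ops without emitting WMMA (both per the docstring in the diff above).
  if not k.apply_tensor_cores(use_tensor_cores=1, tc_opt=2):
    # fall back to ordinary opts; apply_opt now raises KernelOptError instead of asserting
    for opt in [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.LOCAL, 0, 16)]:
      try: k.apply_opt(opt)
      except KernelOptError: pass  # opt not valid for this kernel's shape; skip it
  return k
```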
@@ -474,8 +512,8 @@ class Kernel:
     # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
     MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
     if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
-        self.reduceop and self.reduceop.op
-        (mulop:=self.reduceop.src[0]).op
+        self.reduceop is not None and self.reduceop.op is ReduceOps.SUM and len(self.full_shape) >= 2 and self.opts.has_shared and \
+        (mulop:=self.reduceop.src[0]).op is BinaryOps.MUL and mulop.src[0].op is BufferOps.LOAD and mulop.src[1].op is BufferOps.LOAD:
       st0, st1 = self.sts[self.bufs.index(mulop.src[0].arg)], self.sts[self.bufs.index(mulop.src[1].arg)]
       strides0, strides1 = st0.real_strides(), st1.real_strides()
       def has_expanded_axis(shape, strides): return any(s > 1 and st == 0 for s,st in zip(shape,strides))
@@ -495,11 +533,13 @@ class Kernel:
     # TODO: use 1024 if it's allowed in a smarter way
     for sz in (([256, 16]) if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
       if all(st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts):
-
-
+        try: # may fail due to excessive smem usage
+          self.apply_opt(Opt(OptOps.GROUPTOP, 0, sz))
+          break
+        except KernelOptError: pass

     # are we upcasting in mid reduce? (only for images)
-    if self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.
+    if self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1: # noqa: E501
       axes = self.sts[0].unit_stride_axes()
       assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
       if self.sts[0].shape[axes[0]]%4 == 0:
@@ -517,7 +557,7 @@ class Kernel:
            self.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-self.first_reduce, 4))

     # no more opt if we are grouping
-    if self.
+    if self.group_for_reduces: return

     # **** below this line need to be optional and benchmarked ****

@@ -574,7 +614,7 @@ class Kernel:
     # **** local groups ****

     if self.opts.has_local:
-      if getenv("NOLOCALS") and self.local_dims == 0 and not self.
+      if getenv("NOLOCALS") and self.local_dims == 0 and not self.group_for_reduces:
        self.apply_opt(Opt(OptOps.NOLOCALS))
      else:
        # prioritize making expand axes local