tinygrad-0.9.0-py3-none-any.whl → tinygrad-0.9.1-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- tinygrad/codegen/__init__.py +0 -0
- tinygrad/codegen/kernel.py +78 -90
- tinygrad/codegen/linearizer.py +237 -169
- tinygrad/codegen/uops.py +278 -242
- tinygrad/device.py +147 -10
- tinygrad/dtype.py +7 -7
- tinygrad/engine/graph.py +16 -16
- tinygrad/engine/jit.py +39 -36
- tinygrad/engine/realize.py +6 -5
- tinygrad/engine/schedule.py +15 -7
- tinygrad/engine/search.py +6 -3
- tinygrad/function.py +17 -23
- tinygrad/helpers.py +77 -8
- tinygrad/lazy.py +26 -26
- tinygrad/multi.py +13 -9
- tinygrad/nn/__init__.py +1 -1
- tinygrad/nn/datasets.py +2 -1
- tinygrad/nn/state.py +3 -4
- tinygrad/ops.py +49 -16
- tinygrad/renderer/__init__.py +8 -4
- tinygrad/renderer/assembly.py +93 -100
- tinygrad/renderer/cstyle.py +47 -42
- tinygrad/renderer/llvmir.py +30 -30
- tinygrad/runtime/__init__.py +0 -0
- tinygrad/runtime/autogen/amd_gpu.py +11504 -1
- tinygrad/runtime/autogen/comgr.py +36 -10
- tinygrad/runtime/autogen/hsa.py +146 -14
- tinygrad/runtime/autogen/io_uring.py +1486 -0
- tinygrad/runtime/autogen/nv_gpu.py +269 -0
- tinygrad/runtime/driver/__init__.py +0 -0
- tinygrad/runtime/driver/hip_comgr.py +20 -11
- tinygrad/runtime/graph/__init__.py +0 -0
- tinygrad/runtime/graph/clang.py +3 -2
- tinygrad/runtime/graph/cuda.py +2 -2
- tinygrad/runtime/graph/hcq.py +122 -78
- tinygrad/runtime/ops_amd.py +302 -316
- tinygrad/runtime/ops_cuda.py +3 -3
- tinygrad/runtime/ops_disk.py +70 -5
- tinygrad/runtime/ops_gpu.py +2 -2
- tinygrad/runtime/ops_metal.py +5 -6
- tinygrad/runtime/ops_npy.py +1 -1
- tinygrad/runtime/ops_nv.py +161 -166
- tinygrad/runtime/ops_python.py +20 -16
- tinygrad/shape/__init__.py +0 -0
- tinygrad/shape/shapetracker.py +5 -2
- tinygrad/shape/symbolic.py +1 -3
- tinygrad/shape/view.py +34 -19
- tinygrad/tensor.py +219 -135
- {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +14 -6
- tinygrad-0.9.1.dist-info/RECORD +63 -0
- {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
- tinygrad/runtime/driver/hsa.py +0 -143
- tinygrad/runtime/graph/hsa.py +0 -171
- tinygrad/runtime/ops_hsa.py +0 -278
- tinygrad-0.9.0.dist-info/RECORD +0 -60
- {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
tinygrad/codegen/kernel.py
CHANGED
@@ -1,11 +1,12 @@
 from __future__ import annotations
-import …
-…
-from …
+from collections import defaultdict
+import itertools
+from typing import DefaultDict, Optional, List, Tuple, cast, Dict, Union
+from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, UNSAFE_PAD_OPS, verify_lazyop
 from tinygrad.device import Device
 from tinygrad.renderer import Renderer, TensorCore
 from tinygrad.dtype import dtypes, ImageDType, DType
-from tinygrad.helpers import colored, ansilen, dedup, flatten, getenv, prod, DEBUG, round_up, all_int, get_contraction
+from tinygrad.helpers import all_same, colored, ansilen, dedup, flatten, getenv, prod, DEBUG, round_up, all_int, get_contraction
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.symbolic import sint
 from tinygrad.shape.view import View, strides_for_shape
@@ -34,16 +35,20 @@ class Opt:
     if self.op in {OptOps.GROUP, OptOps.GROUPTOP}: return k.first_reduce+k.group_for_reduces+self.axis
     return self.axis
 
-…
-…
-  axes: …
-  axes_exist: …
-…
-…
-…
-…
-…
-…
+@dataclass
+class TensorCoreOptions:
+  axes: Tuple[int, ...] # the location of the original N and M axes if still in the shape
+  axes_exist: Tuple[bool, ...] # true if the original N and M axes are still in the shape
+  axis_pads: Tuple[Tuple[int, int], ...]
+  def fix_axes(self, removed_axis:int): # adjust the TC axes if necesssary when a dimension is removed
+    axes, axes_exist = list(self.axes), list(self.axes_exist)
+    for tc_dim in [i for i in range(2) if axes_exist[i]]:
+      if removed_axis < axes[tc_dim]: axes[tc_dim] -= 1
+      elif removed_axis == axes[tc_dim]: axes_exist[tc_dim] = False
+    self.axes, self.axes_exist = tuple(axes), tuple(axes_exist)
+
+@dataclass(frozen=True)
+class LocalBuffer:
   name: str
   size: int
   dtype: DType = dtypes.float32
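To make the new fix_axes behavior concrete, here is a minimal standalone sketch: the class body is copied from the added lines above, while the example axis values are invented for illustration and are not taken from tinygrad.

from dataclasses import dataclass
from typing import Tuple

@dataclass
class TensorCoreOptions:
  axes: Tuple[int, ...]                   # location of the original N and M axes, if still in the shape
  axes_exist: Tuple[bool, ...]            # whether the original N and M axes are still in the shape
  axis_pads: Tuple[Tuple[int, int], ...]
  def fix_axes(self, removed_axis:int):
    # shift the tracked N/M axes left when an earlier dimension is removed,
    # and mark an axis as gone when the removed dimension is that axis itself
    axes, axes_exist = list(self.axes), list(self.axes_exist)
    for tc_dim in [i for i in range(2) if axes_exist[i]]:
      if removed_axis < axes[tc_dim]: axes[tc_dim] -= 1
      elif removed_axis == axes[tc_dim]: axes_exist[tc_dim] = False
    self.axes, self.axes_exist = tuple(axes), tuple(axes_exist)

opts = TensorCoreOptions(axes=(2, 5, 7), axes_exist=(True, True), axis_pads=())
opts.fix_axes(1)                    # a dimension before both tracked axes was removed -> both shift left
print(opts.axes, opts.axes_exist)   # (1, 4, 7) (True, True)
opts.fix_axes(4)                    # the tracked M axis itself was removed -> marked as no longer existing
print(opts.axes, opts.axes_exist)   # (1, 4, 7) (True, False)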
@@ -53,24 +58,21 @@ class LocalBuffer(NamedTuple):
 class Kernel:
   def __init__(self, *ast:LazyOp, opts:Optional[Renderer]=None):
     self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
-…
-    assert len(set(op.arg.st.size for op in ast)) == 1, f"all outbufs should have the same size, got {[op.arg.st for op in ast]}"
+    verify_lazyop(*ast)
     self.ast = ast
     self.lazyops = flatten([op.lazyops for op in self.ast])
 
-    # there's only allowed to be one reduceop
     cached_ordered_lazyops: Dict[LazyOp, List[LazyOp]] = {}
     def ordered_lazyops(op):
       if op not in cached_ordered_lazyops: cached_ordered_lazyops[op] = dedup([item for x in op.src for item in ordered_lazyops(x)] + [op])
       return cached_ordered_lazyops[op]
     self.reduceops = dedup([x for out in self.ast for x in ordered_lazyops(out) if x.op in ReduceOps])
-    assert len(self.reduceops) < 2, "Only one reduceop allowed"
 
     self.outbufs, self.vars = [x.arg for x in self.ast], flatten([x.vars() for x in self.ast])
     loadops = [BufferOps.LOAD, BufferOps.CONST]
     self.bufs: List[Union[MemBuffer, ConstBuffer, LocalBuffer]] = self.outbufs + dedup([x.arg for x in self.lazyops if x.op in loadops])
 
-    # get earlybufs, before …
+    # get earlybufs, before any reduceops
     self.earlybufs = [x.arg for reduceop in self.reduceops for x in reduceop.lazyops if x.op in BufferOps]
     self.full_buf_index: int = self.bufs.index(self.earlybufs[0]) if self.earlybufs else 0
 
@@ -87,9 +89,11 @@ class Kernel:
     self.group_for_reduces: int = 0
     self.upcasted: int = 0
     self.local_dims: int = 0
-    self.local_alias: Dict[int, LocalBuffer] = {}
+    self.local_alias: DefaultDict[LazyOp, Dict[int, LocalBuffer]] = defaultdict(dict)
     self.tensor_core: Optional[TensorCore] = None
     self.tensor_core_opts: Optional[TensorCoreOptions] = None
+    # the local aliased buffers for A and B
+    self.bufs_for_tensor_core: Dict[LazyOp, Tuple[int, int]] = {}
     self.dont_use_locals: bool = False
 
     # group simplifies
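The shape of that change, in isolation: local_alias goes from one flat index→buffer mapping to a per-reduceop mapping, so a kernel with several fused reduces can track separate aliases. A tiny sketch of the pattern with placeholder string keys and values (the real keys are LazyOps and the values LocalBuffers; "reduceop_a" and "ldata0" below are hypothetical stand-ins):

from collections import defaultdict
from typing import DefaultDict, Dict

# 0.9.0 style: one mapping for the whole kernel
old_local_alias: Dict[int, str] = {}

# 0.9.1 style: the same mapping, keyed by the reduce op it belongs to
local_alias: DefaultDict[str, Dict[int, str]] = defaultdict(dict)
local_alias["reduceop_a"][0] = "ldata0"   # placeholder for a LocalBuffer aliasing buffer index 0
local_alias["reduceop_b"][1] = "ldata1"
print(local_alias["reduceop_a"])   # {0: 'ldata0'}
print(local_alias["reduceop_c"])   # {} -- defaultdict(dict) creates the inner dict on first access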
@@ -113,7 +117,8 @@ class Kernel:
     # parameters for optimizations
     ret.applied_opts, ret.group_for_reduces, ret.upcasted, ret.local_dims, ret.dont_use_locals = \
       self.applied_opts[:], self.group_for_reduces, self.upcasted, self.local_dims, self.dont_use_locals
-    ret.tensor_core, ret.tensor_core_opts, ret.local_alias = self.tensor_core, self.tensor_core_opts, {}
+    ret.tensor_core, ret.tensor_core_opts, ret.local_alias, ret.bufs_for_tensor_core = self.tensor_core, self.tensor_core_opts, defaultdict(dict), \
+      self.bufs_for_tensor_core
 
     # uncached since linearize didn't run
     ret.applied_opts_cache = None
@@ -256,54 +261,29 @@ class Kernel:
       shapes.append(self.output_shape)
       strides.append(special_strides)
 
-    # merge dimensions if we can, multi …
+    # merge dimensions if we can, multi _merge_dims
     # NOTE: this does not always preserve the reduce dimension
     # TODO: move this into shapetracker, with tests!
-…
+    # TODO: how does this work with multi-reduce?
+    rets = [[(s[0], st[0])] for s,st in zip(shapes, strides)]
     for i in range(1, len(shapes[0])):
      can_merge = []
-      for …
+      for s,st,ret in zip(shapes, strides, rets):
         # TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case
-…
+        si, sti, last_st = s[i], st[i], ret[-1][1]
+        can_merge.append((sti is not None) and ((sti != 0 and last_st == si*sti) or (sti == 0 and last_st == 0)))
       # more can merge than this
       mergeable = all(can_merge) and i != self.first_reduce
-      for j in …
-        if mergeable: rets[j][-1] = (rets[j][-1][0] * …
-        else: rets[j].append((…
+      for j,(s,st) in enumerate(zip(shapes, strides)):
+        if mergeable: rets[j][-1] = (rets[j][-1][0] * s[i], st[i])
+        else: rets[j].append((s[i], st[i]))
 
     # do the reshapes
     for i,x in enumerate(rets[:len(self.sts)]): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x]))
 
   # ******************** helpers ********************
 
-  def _limit_size(…
-    new_shape = list(x)
-    for i in range(len(new_shape)):
-      next_idx = (i + 1) % len(new_shape)
-      while new_shape[i] > max_size[i]:
-        # TODO: what if new_shape[i] is not a multiple of 2??
-        new_shape[i] = new_shape[i] // 2
-        next_idx = next_idx if new_shape[next_idx] <= max_size[next_idx] else (next_idx + 1) % len(new_shape)
-        new_shape[next_idx] = new_shape[next_idx] * 2
-    return tuple(new_shape)
-
-  def limit_dims_to_max(self, global_max: List[int], local_max: List[int]):
-    # Check the global allocation limit, current the global_size will be flipped during codegen
-    # and then padded right with 1s if its length < 3 which makes this part a bit awkward to write
-    if self.global_dims > 0:
-      if global_max:
-        tmp = global_max[:self.global_dims] + (local_max[:self.local_dims] if local_max else [])
-        if max(global_max) < max(self.full_shape[:self.global_dims]):
-          self.reshape_and_permute(lambda x: self._limit_size(x, tmp + [math.inf] * (len(self.full_shape)-len(tmp))), None)
-        assert max(global_max) >= max(self.full_shape[:self.global_dims]), f"device max allocation {max(self.full_shape[:self.global_dims])} exceeds global dim maximum {max(global_max)}" # noqa: E501
-      for i in range(self.global_dims-1):
-        if i < len(global_max) and self.full_shape[i] > global_max[i]:
-          order = list(range(len(self.full_shape)))
-          order[i], order[self.global_dims-1] = order[self.global_dims-1], order[i]
-          self.reshape_and_permute(None, order)
-          if DEBUG >= 3: print("permuted global dim", order, "due to allocation exceeds global limit")
-
-  def alias_buffer(self, i, pattern):
+  def alias_buffer(self, op:LazyOp, i:int, pattern:List[int]) -> None:
     assert len(pattern) == len(self.sts[i].shape), f"must include a pattern for each shape {pattern} {self.sts[i].shape}"
 
     bst = 1
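The merging loop above is self-contained enough to restate as a plain helper over (shape, strides) tuples: adjacent dimensions fuse when, in every view, the pair is contiguous (previous stride == current size * current stride) or both are broadcast (stride 0). This is an illustrative sketch, not tinygrad API; merge_dims and the example views below are ours.

from typing import List, Optional, Tuple

def merge_dims(shapes: List[Tuple[int, ...]], strides: List[Tuple[Optional[int], ...]],
               first_reduce: int) -> List[List[Tuple[int, Optional[int]]]]:
  # same merging rule as the loop above, but over plain tuples
  rets = [[(s[0], st[0])] for s, st in zip(shapes, strides)]
  for i in range(1, len(shapes[0])):
    can_merge = []
    for s, st, ret in zip(shapes, strides, rets):
      si, sti, last_st = s[i], st[i], ret[-1][1]
      can_merge.append((sti is not None) and ((sti != 0 and last_st == si*sti) or (sti == 0 and last_st == 0)))
    mergeable = all(can_merge) and i != first_reduce
    for j, (s, st) in enumerate(zip(shapes, strides)):
      if mergeable: rets[j][-1] = (rets[j][-1][0] * s[i], st[i])
      else: rets[j].append((s[i], st[i]))
  return rets

# a (4, 8) row-major view with strides (8, 1): the two dims fuse into one of size 32
print(merge_dims([(4, 8)], [(8, 1)], first_reduce=2))   # [[(32, 1)]]
# a broadcast dim (stride 0) next to a real dim does not fuse with it
print(merge_dims([(4, 8)], [(0, 1)], first_reduce=2))   # [[(4, 0), (8, 1)]]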
@@ -318,60 +298,67 @@ class Kernel:
     self.sts.append(ShapeTracker((View.create(tuple(shp), tuple(stride)),)))
     self.bufs.append(LocalBuffer(name=f"ldata{i}", size=self.sts[-1].size))
     if DEBUG >= 4: print("aliasing buffer", self.sts[i])
-    self.local_alias[i] = cast(LocalBuffer, self.bufs[-1])
+    self.local_alias[op][i] = cast(LocalBuffer, self.bufs[-1])
 
   # ******************** high level optimizers ********************
 
+  def _create_tc_opts(self, reduceop:LazyOp, tc:TensorCore, axis:int, opt_level:int) -> Optional[TensorCoreOptions]:
+    has_cast = tc.dtype_in != tc.dtype_out
+    if has_cast and not(reduceop.src[0].op is UnaryOps.CAST and reduceop.src[0].arg == tc.dtype_out): return None
+
+    mul_op = reduceop.src[0].src[0] if has_cast else reduceop.src[0]
+    if mul_op.op is not BinaryOps.MUL: return None
+
+    def buf_index(src: LazyOp) -> Optional[int]:
+      # TODO: apply tc even if the sources are not from LOAD
+      if src.op is BufferOps.LOAD and src.arg.dtype == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.arg))
+      try:
+        if opt_level >= 1 and src.op is UnaryOps.CAST and src.arg == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
+      except ValueError: return None
+      return None
+    if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: return None
+
+    buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
+    axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0]
+    axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0]
+    if not(axis_buf0 and axis_buf1 and ((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1))): return None
+
+    axis_choices = list(itertools.product(axis_buf0, axis_buf1, range(self.first_reduce, self.shape_len)))
+    if not(axis < len(axis_choices)): return None
+
+    s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2] # s0 is n, s1 is m, s2 is k
+    axis_pads = tuple((x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if self.full_shape[x]%tc.dims[i] != 0)
+    if axis_pads and (opt_level < 2): return None
+    self.bufs_for_tensor_core[reduceop] = (buf0, buf1)
+    if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
+    return TensorCoreOptions(axes=(s0, s1, s2), axes_exist=(True, True), axis_pads=axis_pads)
+
   def _apply_tc_opt(self, use_tensor_cores:int, axis:int, opt_level:int) -> bool:
     if use_tensor_cores and self.opts.has_local and self.reduceop is not None and self.reduceop.op is ReduceOps.SUM:
       for tc in self.opts.tensor_cores:
-        …
-        …
-        …
-        …
-        if mul_op.op is not BinaryOps.MUL: continue
-
-        def buf_index(src: LazyOp) -> Optional[int]:
-          # TODO: apply tc even if the sources are not from LOAD
-          if src.op is BufferOps.LOAD and src.arg.dtype == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.arg))
-          try:
-            if opt_level >= 1 and src.op is UnaryOps.CAST and src.arg == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
-          except ValueError: return None
-          return None
-        if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: continue
-
-        buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
-        axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0]
-        axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0]
-        if not(axis_buf0 and axis_buf1 and ((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1))): continue
-
-        axis_choices = list(itertools.product(axis_buf0, axis_buf1, range(self.first_reduce, self.shape_len)))
-        if not(axis < len(axis_choices)): continue
-
-        s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2] # s0 is n, s1 is m, s2 is k
-        axis_pads = [(x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if self.full_shape[x]%tc.dims[i] != 0]
-        if axis_pads and (opt_level < 2): continue
-
+        tensor_core_opts = [self._create_tc_opts(reduceop, tc, axis, opt_level) for reduceop in self.reduceops]
+        # can only fuse reduces with the same tc options
+        assert all_same(tensor_core_opts)
+        if tensor_core_opts[0] is None: continue
         # tensor core -- unroll the reduce dim, upcast input, then create the correct thread pattern
-        self.tensor_core_opts = …
+        self.tensor_core_opts = tc_opts = tensor_core_opts[0]
 
         # attempt to pad the tensor axes that require it
         try:
-          for axis, dim in axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
+          for axis, dim in tc_opts.axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
         except KernelOptError: continue
-        self.apply_opt(Opt(OptOps.UNROLL, …
+        self.apply_opt(Opt(OptOps.UNROLL, tc_opts.axes[2]-self.first_reduce, tc.dims[2]), append_opt=False)
         for i, sz in enumerate([prod(x) for x in [[x[1] for x in tc.threads if x[0]==dim] for dim in range(2)]]): # upcast non-local'd N, M
           if tc.dims[i] > sz: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[i], tc.dims[i]//sz), append_opt=False)
         for (tc_dim, tc_amt) in tc.threads:
           self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[tc_dim], tc_amt), append_opt=False)
 
         # assert tensor core
-        if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
         if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA
         return True
     return False
 
-  def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None, axis:int=0, tc_opt:int=…
+  def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None, axis:int=0, tc_opt:Optional[int]=None) -> bool:
     """ Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true, otherwise return false.
     Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N).
 
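The control-flow change in _apply_tc_opt is the part worth pausing on: candidate tensor-core options are now computed once per reduce op and must all agree before any are applied. A hedged sketch of that pattern with stand-in types (pick_shared_option and the string tags are ours; all_same mirrors the helper imported from tinygrad.helpers):

from typing import Any, Callable, List, Optional

def all_same(items: List[Any]) -> bool:
  # same contract as tinygrad.helpers.all_same: every element equals the first
  return all(x == items[0] for x in items)

def pick_shared_option(reduceops: List[str], create_opts: Callable[[str], Optional[str]]) -> Optional[str]:
  # mirrors the new loop body: compute a candidate option per reduce op,
  # insist they agree, then use the first (None means the caller should skip this tc)
  opts = [create_opts(r) for r in reduceops]
  assert all_same(opts), "can only fuse reduces with the same tc options"
  return opts[0]

# hypothetical stand-in for _create_tc_opts: returns an option tag or None
print(pick_shared_option(["r0", "r1"], lambda r: "16x16x16"))   # '16x16x16'
print(pick_shared_option(["r0"], lambda r: None))               # None -> caller would `continue`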
@@ -386,6 +373,7 @@ class Kernel:
       1: allows kernels with multiple reduce axes and also multiplication of UnaryOps.CAST'd buffers
       2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by applying padding those axes as needed
     """
+    if tc_opt is None: tc_opt = self.opts.tc_opt
     if not self.opts.tensor_cores and use_tensor_cores != 2: return False
     try: # check TC first and apply hand-coded opts if successful
       self.apply_opt(Opt(OptOps.TC, axis, tc_opt))
@@ -418,7 +406,7 @@ class Kernel:
     if opt.op is OptOps.TC:
       check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine
       check(opt.axis is not None and opt.amt is not None, "tensor core opts must have an axis and amt")
-      check((use_tensor_cores:=…
+      check((use_tensor_cores:=self.opts.tc) == 2 or len(self.opts.tensor_cores) > 0, "must have tensor cores or TC=2")
       check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), cast(int, opt.amt)), "no tensor core available")
       self.applied_opts.append(opt)
       return