tinygrad 0.9.0__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. tinygrad/codegen/__init__.py +0 -0
  2. tinygrad/codegen/kernel.py +78 -90
  3. tinygrad/codegen/linearizer.py +237 -169
  4. tinygrad/codegen/uops.py +278 -242
  5. tinygrad/device.py +147 -10
  6. tinygrad/dtype.py +7 -7
  7. tinygrad/engine/graph.py +16 -16
  8. tinygrad/engine/jit.py +39 -36
  9. tinygrad/engine/realize.py +6 -5
  10. tinygrad/engine/schedule.py +15 -7
  11. tinygrad/engine/search.py +6 -3
  12. tinygrad/function.py +17 -23
  13. tinygrad/helpers.py +77 -8
  14. tinygrad/lazy.py +26 -26
  15. tinygrad/multi.py +13 -9
  16. tinygrad/nn/__init__.py +1 -1
  17. tinygrad/nn/datasets.py +2 -1
  18. tinygrad/nn/state.py +3 -4
  19. tinygrad/ops.py +49 -16
  20. tinygrad/renderer/__init__.py +8 -4
  21. tinygrad/renderer/assembly.py +93 -100
  22. tinygrad/renderer/cstyle.py +47 -42
  23. tinygrad/renderer/llvmir.py +30 -30
  24. tinygrad/runtime/__init__.py +0 -0
  25. tinygrad/runtime/autogen/amd_gpu.py +11504 -1
  26. tinygrad/runtime/autogen/comgr.py +36 -10
  27. tinygrad/runtime/autogen/hsa.py +146 -14
  28. tinygrad/runtime/autogen/io_uring.py +1486 -0
  29. tinygrad/runtime/autogen/nv_gpu.py +269 -0
  30. tinygrad/runtime/driver/__init__.py +0 -0
  31. tinygrad/runtime/driver/hip_comgr.py +20 -11
  32. tinygrad/runtime/graph/__init__.py +0 -0
  33. tinygrad/runtime/graph/clang.py +3 -2
  34. tinygrad/runtime/graph/cuda.py +2 -2
  35. tinygrad/runtime/graph/hcq.py +122 -78
  36. tinygrad/runtime/ops_amd.py +302 -316
  37. tinygrad/runtime/ops_cuda.py +3 -3
  38. tinygrad/runtime/ops_disk.py +70 -5
  39. tinygrad/runtime/ops_gpu.py +2 -2
  40. tinygrad/runtime/ops_metal.py +5 -6
  41. tinygrad/runtime/ops_npy.py +1 -1
  42. tinygrad/runtime/ops_nv.py +161 -166
  43. tinygrad/runtime/ops_python.py +20 -16
  44. tinygrad/shape/__init__.py +0 -0
  45. tinygrad/shape/shapetracker.py +5 -2
  46. tinygrad/shape/symbolic.py +1 -3
  47. tinygrad/shape/view.py +34 -19
  48. tinygrad/tensor.py +219 -135
  49. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +14 -6
  50. tinygrad-0.9.1.dist-info/RECORD +63 -0
  51. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
  52. tinygrad/runtime/driver/hsa.py +0 -143
  53. tinygrad/runtime/graph/hsa.py +0 -171
  54. tinygrad/runtime/ops_hsa.py +0 -278
  55. tinygrad-0.9.0.dist-info/RECORD +0 -60
  56. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +0 -0
  57. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
tinygrad/codegen/kernel.py
@@ -1,11 +1,12 @@
  from __future__ import annotations
- import math, itertools
- from typing import NamedTuple, Optional, List, Tuple, cast, Dict, Union
- from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, UNSAFE_PAD_OPS
+ from collections import defaultdict
+ import itertools
+ from typing import DefaultDict, Optional, List, Tuple, cast, Dict, Union
+ from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, UNSAFE_PAD_OPS, verify_lazyop
  from tinygrad.device import Device
  from tinygrad.renderer import Renderer, TensorCore
  from tinygrad.dtype import dtypes, ImageDType, DType
- from tinygrad.helpers import colored, ansilen, dedup, flatten, getenv, prod, DEBUG, round_up, all_int, get_contraction
+ from tinygrad.helpers import all_same, colored, ansilen, dedup, flatten, getenv, prod, DEBUG, round_up, all_int, get_contraction
  from tinygrad.shape.shapetracker import ShapeTracker
  from tinygrad.shape.symbolic import sint
  from tinygrad.shape.view import View, strides_for_shape
@@ -34,16 +35,20 @@ class Opt:
  if self.op in {OptOps.GROUP, OptOps.GROUPTOP}: return k.first_reduce+k.group_for_reduces+self.axis
  return self.axis

- class TensorCoreOptions(NamedTuple):
- bufs: Tuple[int, int] # the local aliased buffers for A and B
- axes: List[int] # the location of the original N and M axes if still in the shape
- axes_exist: List[bool] # true if the original N and M axes are still in the shape
- def fix_axes(self, removed_axis:int): # adjust the TC axes if necesssary when an dimension is removed
- for tc_dim in [i for i in range(2) if self.axes_exist[i]]:
- if removed_axis < self.axes[tc_dim]: self.axes[tc_dim] -= 1
- elif removed_axis == self.axes[tc_dim]: self.axes_exist[tc_dim] = False
-
- class LocalBuffer(NamedTuple):
+ @dataclass
+ class TensorCoreOptions:
+ axes: Tuple[int, ...] # the location of the original N and M axes if still in the shape
+ axes_exist: Tuple[bool, ...] # true if the original N and M axes are still in the shape
+ axis_pads: Tuple[Tuple[int, int], ...]
+ def fix_axes(self, removed_axis:int): # adjust the TC axes if necesssary when a dimension is removed
+ axes, axes_exist = list(self.axes), list(self.axes_exist)
+ for tc_dim in [i for i in range(2) if axes_exist[i]]:
+ if removed_axis < axes[tc_dim]: axes[tc_dim] -= 1
+ elif removed_axis == axes[tc_dim]: axes_exist[tc_dim] = False
+ self.axes, self.axes_exist = tuple(axes), tuple(axes_exist)
+
+ @dataclass(frozen=True)
+ class LocalBuffer:
  name: str
  size: int
  dtype: DType = dtypes.float32
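
Aside (not part of the diff): the reworked fix_axes above is easiest to see in isolation. Below is a trimmed, standalone copy of the new dataclass with invented axis values, just to show how removing a dimension shifts or drops the tracked N/M axes.

from dataclasses import dataclass
from typing import Tuple

@dataclass
class TensorCoreOptions:  # trimmed copy of the class added in this hunk, for illustration only
  axes: Tuple[int, ...]               # positions of the N, M (and K) axes in the kernel shape
  axes_exist: Tuple[bool, ...]        # whether the original N and M axes are still present
  axis_pads: Tuple[Tuple[int, int], ...]
  def fix_axes(self, removed_axis:int):
    axes, axes_exist = list(self.axes), list(self.axes_exist)
    for tc_dim in [i for i in range(2) if axes_exist[i]]:
      if removed_axis < axes[tc_dim]: axes[tc_dim] -= 1              # a dim before N/M was removed: shift left
      elif removed_axis == axes[tc_dim]: axes_exist[tc_dim] = False  # the N/M axis itself was removed
    self.axes, self.axes_exist = tuple(axes), tuple(axes_exist)

opts = TensorCoreOptions(axes=(2, 4, 5), axes_exist=(True, True), axis_pads=())  # invented values
opts.fix_axes(3)                    # a dimension between N and M goes away
print(opts.axes, opts.axes_exist)   # (2, 3, 5) (True, True): M shifted from 4 to 3
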
@@ -53,24 +58,21 @@ class LocalBuffer(NamedTuple):
  class Kernel:
  def __init__(self, *ast:LazyOp, opts:Optional[Renderer]=None):
  self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
- assert all(op.op is BufferOps.STORE for op in ast), f"kernels must have stores as the output, got {ast}"
- assert len(set(op.arg.st.size for op in ast)) == 1, f"all outbufs should have the same size, got {[op.arg.st for op in ast]}"
+ verify_lazyop(*ast)
  self.ast = ast
  self.lazyops = flatten([op.lazyops for op in self.ast])

- # there's only allowed to be one reduceop
  cached_ordered_lazyops: Dict[LazyOp, List[LazyOp]] = {}
  def ordered_lazyops(op):
  if op not in cached_ordered_lazyops: cached_ordered_lazyops[op] = dedup([item for x in op.src for item in ordered_lazyops(x)] + [op])
  return cached_ordered_lazyops[op]
  self.reduceops = dedup([x for out in self.ast for x in ordered_lazyops(out) if x.op in ReduceOps])
- assert len(self.reduceops) < 2, "Only one reduceop allowed"

  self.outbufs, self.vars = [x.arg for x in self.ast], flatten([x.vars() for x in self.ast])
  loadops = [BufferOps.LOAD, BufferOps.CONST]
  self.bufs: List[Union[MemBuffer, ConstBuffer, LocalBuffer]] = self.outbufs + dedup([x.arg for x in self.lazyops if x.op in loadops])

- # get earlybufs, before the one reduce op
+ # get earlybufs, before any reduceops
  self.earlybufs = [x.arg for reduceop in self.reduceops for x in reduceop.lazyops if x.op in BufferOps]
  self.full_buf_index: int = self.bufs.index(self.earlybufs[0]) if self.earlybufs else 0

@@ -87,9 +89,11 @@ class Kernel:
  self.group_for_reduces: int = 0
  self.upcasted: int = 0
  self.local_dims: int = 0
- self.local_alias: Dict[int, LocalBuffer] = {}
+ self.local_alias: DefaultDict[LazyOp, Dict[int, LocalBuffer]] = defaultdict(dict)
  self.tensor_core: Optional[TensorCore] = None
  self.tensor_core_opts: Optional[TensorCoreOptions] = None
+ # the local aliased buffers for A and B
+ self.bufs_for_tensor_core: Dict[LazyOp, Tuple[int, int]] = {}
  self.dont_use_locals: bool = False

  # group simplifies
@@ -113,7 +117,8 @@ class Kernel:
  # parameters for optimizations
  ret.applied_opts, ret.group_for_reduces, ret.upcasted, ret.local_dims, ret.dont_use_locals = \
  self.applied_opts[:], self.group_for_reduces, self.upcasted, self.local_dims, self.dont_use_locals
- ret.tensor_core, ret.tensor_core_opts, ret.local_alias = self.tensor_core, self.tensor_core_opts, {}
+ ret.tensor_core, ret.tensor_core_opts, ret.local_alias, ret.bufs_for_tensor_core = self.tensor_core, self.tensor_core_opts, defaultdict(dict), \
+ self.bufs_for_tensor_core

  # uncached since linearize didn't run
  ret.applied_opts_cache = None
@@ -256,54 +261,29 @@ class Kernel:
  shapes.append(self.output_shape)
  strides.append(special_strides)

- # merge dimensions if we can, multi get_shape_strides
+ # merge dimensions if we can, multi _merge_dims
  # NOTE: this does not always preserve the reduce dimension
  # TODO: move this into shapetracker, with tests!
- rets = [[(shapes[j][0], strides[j][0])] for j in range(len(shapes))]
+ # TODO: how does this work with multi-reduce?
+ rets = [[(s[0], st[0])] for s,st in zip(shapes, strides)]
  for i in range(1, len(shapes[0])):
  can_merge = []
- for j in range(len(shapes)):
+ for s,st,ret in zip(shapes, strides, rets):
  # TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case
- can_merge.append(strides[j][i] is not None and ((strides[j][i] != 0 and rets[j][-1][1] == shapes[j][i]*cast(int, strides[j][i])) or (strides[j][i] == 0 and rets[j][-1][1] == 0))) # noqa: E501
+ si, sti, last_st = s[i], st[i], ret[-1][1]
+ can_merge.append((sti is not None) and ((sti != 0 and last_st == si*sti) or (sti == 0 and last_st == 0)))
  # more can merge than this
  mergeable = all(can_merge) and i != self.first_reduce
- for j in range(len(shapes)):
- if mergeable: rets[j][-1] = (rets[j][-1][0] * shapes[j][i], strides[j][i])
- else: rets[j].append((shapes[j][i], strides[j][i]))
+ for j,(s,st) in enumerate(zip(shapes, strides)):
+ if mergeable: rets[j][-1] = (rets[j][-1][0] * s[i], st[i])
+ else: rets[j].append((s[i], st[i]))

  # do the reshapes
  for i,x in enumerate(rets[:len(self.sts)]): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x]))

  # ******************** helpers ********************

- def _limit_size(self, x: Tuple[int], max_size: List[Union[int,float]]) -> Tuple[int, ...]:
- new_shape = list(x)
- for i in range(len(new_shape)):
- next_idx = (i + 1) % len(new_shape)
- while new_shape[i] > max_size[i]:
- # TODO: what if new_shape[i] is not a multiple of 2??
- new_shape[i] = new_shape[i] // 2
- next_idx = next_idx if new_shape[next_idx] <= max_size[next_idx] else (next_idx + 1) % len(new_shape)
- new_shape[next_idx] = new_shape[next_idx] * 2
- return tuple(new_shape)
-
- def limit_dims_to_max(self, global_max: List[int], local_max: List[int]):
- # Check the global allocation limit, current the global_size will be flipped during codegen
- # and then padded right with 1s if its length < 3 which makes this part a bit awkward to write
- if self.global_dims > 0:
- if global_max:
- tmp = global_max[:self.global_dims] + (local_max[:self.local_dims] if local_max else [])
- if max(global_max) < max(self.full_shape[:self.global_dims]):
- self.reshape_and_permute(lambda x: self._limit_size(x, tmp + [math.inf] * (len(self.full_shape)-len(tmp))), None)
- assert max(global_max) >= max(self.full_shape[:self.global_dims]), f"device max allocation {max(self.full_shape[:self.global_dims])} exceeds global dim maximum {max(global_max)}" # noqa: E501
- for i in range(self.global_dims-1):
- if i < len(global_max) and self.full_shape[i] > global_max[i]:
- order = list(range(len(self.full_shape)))
- order[i], order[self.global_dims-1] = order[self.global_dims-1], order[i]
- self.reshape_and_permute(None, order)
- if DEBUG >= 3: print("permuted global dim", order, "due to allocation exceeds global limit")
-
- def alias_buffer(self, i, pattern):
+ def alias_buffer(self, op:LazyOp, i:int, pattern:List[int]) -> None:
  assert len(pattern) == len(self.sts[i].shape), f"must include a pattern for each shape {pattern} {self.sts[i].shape}"

  bst = 1
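
Aside (not part of the diff): the rewritten merge test in the loop above reduces to a single predicate over the previous dim's stride and the current dim's shape and stride. A minimal restatement with made-up values:

# standalone restatement of the can_merge test from the hunk above;
# prev_stride is ret[-1][1], si/sti are the current dim's shape and stride
def can_merge(prev_stride, si, sti):
  # contiguous: the previous stride spans exactly shape*stride of this dim,
  # or both dims are broadcast (stride 0); a None stride (complex/masked view) never merges
  return (sti is not None) and ((sti != 0 and prev_stride == si*sti) or (sti == 0 and prev_stride == 0))

print(can_merge(12, 4, 3))     # True: previous stride 12 == 4*3, the dims are contiguous
print(can_merge(0, 4, 0))      # True: two broadcast dims collapse
print(can_merge(12, 4, 2))     # False: there is a gap, the dims cannot be flattened together
print(can_merge(12, 4, None))  # False: unknown stride
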
@@ -318,60 +298,67 @@ class Kernel:
  self.sts.append(ShapeTracker((View.create(tuple(shp), tuple(stride)),)))
  self.bufs.append(LocalBuffer(name=f"ldata{i}", size=self.sts[-1].size))
  if DEBUG >= 4: print("aliasing buffer", self.sts[i])
- self.local_alias[i] = cast(LocalBuffer, self.bufs[-1])
+ self.local_alias[op][i] = cast(LocalBuffer, self.bufs[-1])

  # ******************** high level optimizers ********************

+ def _create_tc_opts(self, reduceop:LazyOp, tc:TensorCore, axis:int, opt_level:int) -> Optional[TensorCoreOptions]:
+ has_cast = tc.dtype_in != tc.dtype_out
+ if has_cast and not(reduceop.src[0].op is UnaryOps.CAST and reduceop.src[0].arg == tc.dtype_out): return None
+
+ mul_op = reduceop.src[0].src[0] if has_cast else reduceop.src[0]
+ if mul_op.op is not BinaryOps.MUL: return None
+
+ def buf_index(src: LazyOp) -> Optional[int]:
+ # TODO: apply tc even if the sources are not from LOAD
+ if src.op is BufferOps.LOAD and src.arg.dtype == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.arg))
+ try:
+ if opt_level >= 1 and src.op is UnaryOps.CAST and src.arg == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
+ except ValueError: return None
+ return None
+ if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: return None
+
+ buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
+ axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0]
+ axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0]
+ if not(axis_buf0 and axis_buf1 and ((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1))): return None
+
+ axis_choices = list(itertools.product(axis_buf0, axis_buf1, range(self.first_reduce, self.shape_len)))
+ if not(axis < len(axis_choices)): return None
+
+ s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2] # s0 is n, s1 is m, s2 is k
+ axis_pads = tuple((x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if self.full_shape[x]%tc.dims[i] != 0)
+ if axis_pads and (opt_level < 2): return None
+ self.bufs_for_tensor_core[reduceop] = (buf0, buf1)
+ if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
+ return TensorCoreOptions(axes=(s0, s1, s2), axes_exist=(True, True), axis_pads=axis_pads)
+
  def _apply_tc_opt(self, use_tensor_cores:int, axis:int, opt_level:int) -> bool:
  if use_tensor_cores and self.opts.has_local and self.reduceop is not None and self.reduceop.op is ReduceOps.SUM:
  for tc in self.opts.tensor_cores:
- has_cast = tc.dtype_in != tc.dtype_out
- if has_cast and not(self.reduceop.src[0].op is UnaryOps.CAST and self.reduceop.src[0].arg == tc.dtype_out): continue
-
- mul_op = self.reduceop.src[0].src[0] if has_cast else self.reduceop.src[0]
- if mul_op.op is not BinaryOps.MUL: continue
-
- def buf_index(src: LazyOp) -> Optional[int]:
- # TODO: apply tc even if the sources are not from LOAD
- if src.op is BufferOps.LOAD and src.arg.dtype == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.arg))
- try:
- if opt_level >= 1 and src.op is UnaryOps.CAST and src.arg == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
- except ValueError: return None
- return None
- if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: continue
-
- buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
- axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0]
- axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0]
- if not(axis_buf0 and axis_buf1 and ((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1))): continue
-
- axis_choices = list(itertools.product(axis_buf0, axis_buf1, range(self.first_reduce, self.shape_len)))
- if not(axis < len(axis_choices)): continue
-
- s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2] # s0 is n, s1 is m, s2 is k
- axis_pads = [(x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if self.full_shape[x]%tc.dims[i] != 0]
- if axis_pads and (opt_level < 2): continue
-
+ tensor_core_opts = [self._create_tc_opts(reduceop, tc, axis, opt_level) for reduceop in self.reduceops]
+ # can only fuse reduces with the same tc options
+ assert all_same(tensor_core_opts)
+ if tensor_core_opts[0] is None: continue
  # tensor core -- unroll the reduce dim, upcast input, then create the correct thread pattern
- self.tensor_core_opts = (tc_opts:=TensorCoreOptions(bufs=(buf0, buf1), axes=[s0, s1], axes_exist=[True, True]))
+ self.tensor_core_opts = tc_opts = tensor_core_opts[0]

  # attempt to pad the tensor axes that require it
  try:
- for axis, dim in axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
+ for axis, dim in tc_opts.axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
  except KernelOptError: continue
- self.apply_opt(Opt(OptOps.UNROLL, s2-self.first_reduce, tc.dims[2]), append_opt=False)
+ self.apply_opt(Opt(OptOps.UNROLL, tc_opts.axes[2]-self.first_reduce, tc.dims[2]), append_opt=False)
  for i, sz in enumerate([prod(x) for x in [[x[1] for x in tc.threads if x[0]==dim] for dim in range(2)]]): # upcast non-local'd N, M
  if tc.dims[i] > sz: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[i], tc.dims[i]//sz), append_opt=False)
  for (tc_dim, tc_amt) in tc.threads:
  self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[tc_dim], tc_amt), append_opt=False)

  # assert tensor core
- if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
  if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA
  return True
  return False

- def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None, axis:int=0, tc_opt:int=getenv("TC_OPT")) -> bool:
+ def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None, axis:int=0, tc_opt:Optional[int]=None) -> bool:
  """ Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true, otherwise return false.
  Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N).

@@ -386,6 +373,7 @@ class Kernel:
  1: allows kernels with multiple reduce axes and also multiplication of UnaryOps.CAST'd buffers
  2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by applying padding those axes as needed
  """
+ if tc_opt is None: tc_opt = self.opts.tc_opt
  if not self.opts.tensor_cores and use_tensor_cores != 2: return False
  try: # check TC first and apply hand-coded opts if successful
  self.apply_opt(Opt(OptOps.TC, axis, tc_opt))
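
Aside (not part of the diff): a rough sketch of exercising the new default by hand. It assumes the 0.9.1 scheduling API (create_schedule, ScheduleItem.ast) and that the matmul ends up as the last schedule item; on a backend without tensor cores apply_tensor_cores simply returns False.

from tinygrad import Tensor, Device
from tinygrad.engine.schedule import create_schedule
from tinygrad.codegen.kernel import Kernel

a, b = Tensor.rand(64, 64), Tensor.rand(64, 64)
si = create_schedule([(a @ b).lazydata])[-1]          # assumed: the matmul is the last schedule item
k = Kernel(*si.ast, opts=Device[Device.DEFAULT].renderer)
# tc_opt now falls back to the renderer's tc_opt instead of getenv("TC_OPT");
# tc_opt=2 additionally allows padding M/N/K up to the tensor core dims
if k.apply_tensor_cores(use_tensor_cores=1, tc_opt=2):
  print(k.applied_opts)                               # the first applied opt should be the OptOps.TC opt
else:
  print("no tensor core matched this kernel")
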
@@ -418,7 +406,7 @@ class Kernel:
  if opt.op is OptOps.TC:
  check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine
  check(opt.axis is not None and opt.amt is not None, "tensor core opts must have an axis and amt")
- check((use_tensor_cores:=getenv("TC", 1)) == 2 or len(self.opts.tensor_cores) > 0, "must have tensor cores or TC=2")
+ check((use_tensor_cores:=self.opts.tc) == 2 or len(self.opts.tensor_cores) > 0, "must have tensor cores or TC=2")
  check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), cast(int, opt.amt)), "no tensor core available")
  self.applied_opts.append(opt)
  return