tinygrad 0.10.0__py3-none-any.whl → 0.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. tinygrad/codegen/kernel.py +114 -172
  2. tinygrad/codegen/linearize.py +211 -81
  3. tinygrad/codegen/lowerer.py +30 -35
  4. tinygrad/codegen/{uopgraph.py → rewriter.py} +69 -59
  5. tinygrad/codegen/transcendental.py +12 -13
  6. tinygrad/device.py +170 -47
  7. tinygrad/dtype.py +28 -26
  8. tinygrad/engine/jit.py +80 -63
  9. tinygrad/engine/memory.py +4 -5
  10. tinygrad/engine/multi.py +162 -0
  11. tinygrad/engine/realize.py +58 -107
  12. tinygrad/engine/schedule.py +381 -314
  13. tinygrad/engine/search.py +40 -44
  14. tinygrad/gradient.py +70 -0
  15. tinygrad/helpers.py +77 -58
  16. tinygrad/nn/__init__.py +30 -32
  17. tinygrad/nn/datasets.py +1 -2
  18. tinygrad/nn/optim.py +22 -26
  19. tinygrad/nn/state.py +89 -64
  20. tinygrad/ops.py +562 -446
  21. tinygrad/renderer/__init__.py +79 -36
  22. tinygrad/renderer/cstyle.py +70 -84
  23. tinygrad/renderer/llvmir.py +32 -20
  24. tinygrad/renderer/ptx.py +79 -99
  25. tinygrad/renderer/wgsl.py +87 -0
  26. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  27. tinygrad/runtime/autogen/comgr.py +2 -0
  28. tinygrad/runtime/autogen/kfd.py +4 -3
  29. tinygrad/runtime/autogen/kgsl.py +1 -1
  30. tinygrad/runtime/autogen/libpciaccess.py +2023 -0
  31. tinygrad/runtime/autogen/llvm.py +11379 -0
  32. tinygrad/runtime/autogen/vfio.py +891 -0
  33. tinygrad/runtime/graph/cuda.py +8 -9
  34. tinygrad/runtime/graph/hcq.py +84 -79
  35. tinygrad/runtime/graph/metal.py +19 -21
  36. tinygrad/runtime/ops_amd.py +488 -327
  37. tinygrad/runtime/ops_clang.py +15 -28
  38. tinygrad/runtime/ops_cloud.py +34 -34
  39. tinygrad/runtime/ops_cuda.py +30 -27
  40. tinygrad/runtime/ops_disk.py +62 -63
  41. tinygrad/runtime/ops_dsp.py +129 -38
  42. tinygrad/runtime/ops_gpu.py +30 -30
  43. tinygrad/runtime/ops_hip.py +29 -31
  44. tinygrad/runtime/ops_llvm.py +45 -40
  45. tinygrad/runtime/ops_metal.py +93 -73
  46. tinygrad/runtime/ops_npy.py +2 -2
  47. tinygrad/runtime/ops_nv.py +232 -270
  48. tinygrad/runtime/ops_python.py +51 -46
  49. tinygrad/runtime/ops_qcom.py +129 -157
  50. tinygrad/runtime/ops_webgpu.py +63 -0
  51. tinygrad/runtime/support/allocator.py +94 -0
  52. tinygrad/runtime/support/am/__init__.py +0 -0
  53. tinygrad/runtime/support/am/amdev.py +384 -0
  54. tinygrad/runtime/support/am/ip.py +463 -0
  55. tinygrad/runtime/support/compiler_cuda.py +4 -2
  56. tinygrad/runtime/support/elf.py +26 -4
  57. tinygrad/runtime/support/hcq.py +254 -324
  58. tinygrad/runtime/support/llvm.py +32 -0
  59. tinygrad/shape/shapetracker.py +84 -53
  60. tinygrad/shape/view.py +103 -138
  61. tinygrad/spec.py +154 -0
  62. tinygrad/tensor.py +744 -496
  63. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/METADATA +32 -21
  64. tinygrad-0.10.1.dist-info/RECORD +86 -0
  65. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/WHEEL +1 -1
  66. tinygrad/engine/lazy.py +0 -228
  67. tinygrad/function.py +0 -212
  68. tinygrad/multi.py +0 -177
  69. tinygrad/runtime/graph/clang.py +0 -39
  70. tinygrad-0.10.0.dist-info/RECORD +0 -77
  71. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/LICENSE +0 -0
  72. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/top_level.txt +0 -0
tinygrad/codegen/{uopgraph.py → rewriter.py}

@@ -1,14 +1,13 @@
  from __future__ import annotations
- from typing import Optional, Tuple, Dict, List, TYPE_CHECKING, Any, DefaultDict, Callable, Set
+ from typing import Optional, Any, Callable
  import functools, itertools, operator
  from collections import defaultdict
  from tinygrad.dtype import dtypes, ImageDType, PtrDType
- from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, symbolic_flat, symbolic_simple
+ from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, symbolic_flat, symbolic_simple, resolve
  from tinygrad.ops import graph_rewrite, split_uop, uop_given_valid, parse_valid, is_increasing, simplify_valid, GroupOp
  from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same
  from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
-
- if TYPE_CHECKING: from tinygrad.renderer import Renderer
+ from tinygrad.renderer import Renderer

  # ***** float4/image store handling *****

@@ -19,7 +18,7 @@ def fold_expanded(ex, buf):
  is_load, is_image = new_srcs[0].op is Ops.LOAD, isinstance(buf.dtype, ImageDType)

  # first, extract all the relevant offsets
- offsets_rootsrc: DefaultDict[Any, dict] = defaultdict(dict)
+ offsets_rootsrc: defaultdict[Any, dict] = defaultdict(dict)
  for i,s in enumerate(new_srcs):
  idx = s.src[0].src[1]
  if s.dtype.count != 1 or (is_image and idx.dtype.count == 2): continue
@@ -33,7 +32,7 @@ def fold_expanded(ex, buf):

  # then rewrite everything we can
  lengths = [4] if is_image else ([8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2]))
- used: Set[Tuple[UOp, UOp]] = set()
+ used: set[tuple[UOp, UOp]] = set()
  for rootsrc, offsets in offsets_rootsrc.items():
  for o in offsets:
  for fold_length in lengths:
@@ -49,7 +48,8 @@ def fold_expanded(ex, buf):
  rootsrc[0] if isinstance(rootsrc, tuple) else None)
  else:
  # for non image, we upcast the index pointer
- new_src[0] = new_src[0].cast(new_src[0].dtype.base.vec(fold_length).ptr(new_src[0].dtype.local))
+ new_src[0] = new_src[0].cast(new_src[0].dtype.base.vec(fold_length).ptr(size=new_src[0].dtype.size//fold_length,
+ local=new_src[0].dtype.local))
  # generate the folded new_srcs
  if is_load:
  new_load = UOp(Ops.LOAD, load_1.dtype.vec(fold_length), tuple(new_src))
@@ -83,7 +83,7 @@ float4_folding = PatternMatcher([

  # ***** image load valid simplification *****

- def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> Optional[UOp]:
+ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
  if (idx:=uop_given_valid(valid, start_idx)) is None: return buf.const_like(0)
  if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx, valid)

@@ -122,19 +122,17 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> Optional[UOp]:
  powers_of_two = {2**i:i for i in range(64)}
  @functools.lru_cache(None)
  def get_late_rewrite_patterns(ops, force_transcendental=False):
- pat: List[Tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \
+ pat: list[tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \
  ((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)) if op not in ops or force_transcendental]
- # rewrite MOD to AND (which should always be supported, but not for generic in tests)
+ # rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1)
  if Ops.AND in ops:
- pat += [(UPat(Ops.MOD, src=(UPat.var('base'), UPat.cvar("const"))),
- lambda base,const: base & (const.arg-1) if const.arg in powers_of_two else None)]
- # rewrite MUL/IDIV to SHL+SHR
+ pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)]
+ # rewrite MUL/IDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y)
  if Ops.SHL in ops and Ops.SHR in ops:
  pat += [
- (UPat(Ops.MUL, dtype=dtypes.ints, src=[UPat.cvar("const"), UPat.var("mul")]), lambda mul, const:
- mul << powers_of_two[const.arg] if const.arg in powers_of_two else None), # (x * (2**y)) -> shl(x,y)
- (UPat(Ops.IDIV, src=(UPat.var("div"), UPat.cvar("const"))), lambda div, const:
- div >> powers_of_two[const.arg] if const.arg in powers_of_two else None)] # (x // (2**y)) -> shr(x,y)
+ (UPat.var("x", dtypes.ints)*UPat.cvar("c"), lambda c,x: x << powers_of_two[c.arg] if c.arg in powers_of_two else None),
+ (UPat.var("x", dtypes.ints)//UPat.cvar("c"), lambda x,c: x >> powers_of_two[c.arg] if c.arg in powers_of_two and resolve(x>=0,False) else None)
+ ]
  if Ops.NEG in ops:
  pat += [(UPat.var('x')*-1, lambda x: x.alu(Ops.NEG))]
  if Ops.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(Ops.NEG), lambda x,y: x.alu(Ops.SUB, y))]
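The rewritten patterns above lean on the standard power-of-two identities. A minimal plain-Python sketch of those identities (not tinygrad's UOp machinery), including why the new IDIV→SHR rule is guarded with `resolve(x>=0, False)`:

```python
import math

# x % 2**y == x & (2**y - 1) and x * 2**y == x << y hold for non-negative ints
for x in range(256):
    for y in range(1, 5):
        c = 2**y
        assert x % c == x & (c - 1)   # MOD -> AND
        assert x * c == x << y        # MUL -> SHL
        assert x // c == x >> y       # IDIV -> SHR (x >= 0 only)

# a truncating (C-style) integer division disagrees with an arithmetic shift
# for negative x, which is presumably why the shift rewrite must prove x >= 0
x, y = -7, 2
print(math.trunc(x / 2**y))  # -1: division truncates toward zero
print(x >> y)                # -2: shift floors toward -inf
```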
@@ -191,7 +189,7 @@ def loop_collapse(compval, multconst, rng:UOp, acc:UOp, idx2=None,idx3=None,extr

  def index_collapse(idx:UOp,rng:UOp,buf:UOp,ld:UOp,acc:UOp,add=UOp.const(dtypes.int, 0),mul=UOp.const(dtypes.int, 1)):
  if rng not in acc.src: return None
- new_load = UOp.load(buf.index(add+mul*idx, idx.ge(rng.src[0]) & idx.lt(rng.src[1])), dtype=ld.dtype)
+ new_load = UOp.load(buf.index(add+mul*idx, (idx >= rng.src[0]) & (idx < rng.src[1])), dtype=ld.dtype)
  new_acc = acc.replace(src=acc.src[0:1]+tuple(x for x in acc.src[1:] if x is not rng))
  return new_acc.assign(new_acc+new_load)
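`index_collapse` turns a loop that sums a one-hot selection into a single bounds-checked load. A plain-Python model of that equivalence (illustrative only):

```python
# summing buf[r] over the loop, gated on idx == r, yields the same value as a
# single load of buf[idx] gated on the loop bounds, so the loop can be deleted
buf = [10, 20, 30, 40]
lo, hi = 0, len(buf)
for idx in range(-2, 6):
    looped = sum(buf[r] if idx == r else 0 for r in range(lo, hi))
    collapsed = buf[idx] if lo <= idx < hi else 0
    assert looped == collapsed
```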
@@ -221,7 +219,7 @@ def no_vectorized_wmma(wmma:UOp):
  return UOp(Ops.VECTORIZE, wmma.dtype, tuple(wmma_ex))

  def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
- reduce_parented, reduce_unparented = partition(acc.src[1:], lambda x: x in ret.sparents)
+ reduce_parented, reduce_unparented = partition(acc.src[1:], lambda x: x in ret.toposort)
  if len(reduce_unparented) == 0: return None
  new_acc = acc.replace(src=acc.src[0:1]+tuple(reduce_parented))
  ret = new_acc.assign(new_acc.alu(alu.op, ret))
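`reduce_collapse` partitions the accumulator's ranges into those the reduced expression actually depends on (now found via `toposort` instead of the removed `sparents`) and those it does not. Conceptually, an unparented range adds no information, as this small illustration shows:

```python
# if the body never mentions the loop variable, the loop can be removed:
# ADD contributes the value n times, MAX contributes it once
n, value = 8, 3.5
assert sum(value for _ in range(n)) == n * value
assert max(value for _ in range(n)) == value
```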
@@ -235,17 +233,18 @@ rng_aug = UPat.any(rng_pat, UPat.var("add")+rng_pat, UPat.var("mul")*rng_pat, UP
  index_load = UPat.var("buf").index(rng_aug).load(name="ld")

  arange_augrng = UPat.any(rng_aug, rng_aug+UPat.var("idx2"), rng_aug+UPat.var("idx2")+UPat.var("idx3"), UPat(Ops.VECTORIZE, name="vec", src=rng_aug))
- arange_m = arange_augrng.lt(UPat.cvar("compval")).ne(UPat(Ops.CONST, name="ne", arg=True)).where(UPat.cvar("multconst"), UPat.const(None, 0))
+ arange_m = ((arange_augrng<UPat.cvar("compval"))!=UPat(Ops.CONST, name="ne", arg=True)).where(UPat.cvar("multconst"), UPat.const(None, 0))
+
+ # this moves the accumulation variable down an unrolled add chain which allows for more efficient accumulation using mulacc
+ mulacc_unrolled = PatternMatcher([(UPat.var("x")+UPat.var("y")+acc_pat, lambda x,y,acc: (acc+x)+y if y.op is not Ops.DEFINE_ACC else None)])

  # this is symbolic 2.0
  sym = symbolic_flat+PatternMatcher([
  # self ASSIGN is just self
  (UPat(Ops.ASSIGN, src=(UPat.var('x'), UPat.var('x'))), lambda x: x),
- # ASSIGN to global is just self
- (UPat(Ops.ASSIGN, src=(UPat(Ops.DEFINE_GLOBAL), UPat.var("x"))), lambda x: x),
  # VECTORIZE/CONST, VECTORIZE/GEP
  (UPat(Ops.VECTORIZE, src=UPat(Ops.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))),
- (UPat(Ops.VECTORIZE, src=UPat(Ops.GEP, src=(UPat(name="x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))),
+ (UPat(Ops.VECTORIZE, src=UPat(Ops.GEP, src=(UPat.var("x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))),
  # reorder ALU/VECTORIZE
  (UPat(GroupOp.ALU, src=(UPat(Ops.VECTORIZE, src=UPat(name='x')), UPat(Ops.VECTORIZE, src=UPat(name='y'))), name='alu'),
  lambda x,y,alu: UOp(Ops.VECTORIZE, alu.dtype, (UOp(alu.op, alu.dtype.scalar(), (x,y)),)*alu.dtype.count)),
@@ -288,14 +287,16 @@ sym = symbolic_flat+PatternMatcher([
  # indexing, with cast or where
  (acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).cast()*index_load+acc_pat), index_collapse),
  (acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).where(index_load, UPat.const(None, 0.0))+acc_pat), index_collapse),
- # parentless reduce
- (acc_pat.assign(UPat(Ops.ADD, src=[acc_pat, UPat.var("ret")], name="alu")), reduce_collapse),
- (acc_pat.assign(UPat(Ops.MAX, src=[acc_pat, UPat.var("ret")], name="alu")), reduce_collapse),
+ # parentless reduce # TODO: add MUL
+ (acc_pat.assign(UPat((Ops.ADD, Ops.MAX), src=[acc_pat, UPat.var("ret")], name="alu")), reduce_collapse),
  # ** self folding **
  (UPat(Ops.DEFINE_ACC, src=(UPat.var("x"),)), lambda x: x), # a DEFINE_ACC without ranges is a CONST
  (UPat(Ops.ASSIGN, src=(UPat.cvar(),UPat.var("x"))), lambda x: x), # an ASSIGN to a const is a NOOP
  # x!=0 -> (bool)x
- (UPat.var("x").ne(0), lambda x: x.cast(dtypes.bool.vec(x.dtype.count))),
+ (UPat.var("x")!=0, lambda x: x.cast(dtypes.bool.vec(x.dtype.count))),
+ # ** where **
+ # push cast to branches
+ (UPat.var("s").where(UPat.var("a"), UPat.var("b")).cast().named("cast"), lambda s,a,b,cast: s.where(a.cast(cast.dtype), b.cast(cast.dtype))),
  # ** load/store folding **
  (UPat.store(UPat(Ops.INDEX, name="index"), UPat.load(UPat(Ops.INDEX, name="index"))), lambda index: UOp(Ops.NOOP)),
  (UPat.store(UPat(Ops.INDEX, name="index"), UPat.var("gate").where(UPat.var("alt"), UPat.load(UPat(Ops.INDEX, name="index")))),
@@ -308,31 +309,36 @@ sym = symbolic_flat+PatternMatcher([
  # remove NOOPs from SINK
  (UPat(Ops.SINK, name="root"),
  lambda root: UOp(Ops.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not Ops.NOOP)) != len(root.src) else None),
- # remove EXPANDs from SINK/BARRIER
+ # remove VECTORIZE from SINK/BARRIER
  (UPat(Ops.BARRIER, src=(UPat((Ops.VECTORIZE, Ops.SINK), name='sink'),)), lambda sink: UOp(Ops.BARRIER, dtypes.void, sink.src)),
  (UPat(Ops.SINK, name="root"),
- lambda root: UOp(Ops.SINK, root.dtype, tuple(flatten(x.src if x.op in {Ops.SINK, Ops.EXPAND} else (x,) for x in root.src)), root.arg)
- if any(x.op in {Ops.SINK, Ops.EXPAND} for x in root.src) else None),
+ lambda root: UOp(Ops.SINK, root.dtype, tuple(flatten(x.src if x.op in {Ops.SINK, Ops.UNROLL} else (x,) for x in root.src)), root.arg)
+ if any(x.op in {Ops.SINK, Ops.UNROLL} for x in root.src) else None),
+ ((UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()), # 1/(x^c) -> (1/x)^c
+ ((UPat.var("x") * UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()*x.reciprocal()),
+ (UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")), lambda x,d: 1-d), # x*/(1+x) -> 1-1/(1+x)
+ (UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")*UPat.var("y")), lambda x,y,d: y*(1-d)),
+ (UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")+UPat.var("y")), lambda x,y,d: (1-d)+x*y),
  ])

  # *** uop expander ***

- def _expand_arg_to_idx(args:Tuple[Tuple[int, int], ...], rpk:Dict[int, int]) -> int:
+ def _expand_arg_to_idx(args:tuple[tuple[int, int], ...], rpk:dict[int, int]) -> int:
  idx, mul = 0, 1
  for axis,m in args[::-1]:
  idx += rpk[axis] * mul
  mul *= m
  return idx

- def _choices_from_args(args:Tuple[Tuple[int, int], ...]) -> List[Dict[int, int]]:
+ def _choices_from_args(args:tuple[tuple[int, int], ...]) -> list[dict[int, int]]:
  return [dict(x) for x in itertools.product(*[zip(itertools.repeat(axis), range(m)) for axis,m in args])]

  @functools.lru_cache(None)
- def _swizzle_args(cargs:Tuple[Tuple[int, int], ...], eargs:Tuple[Tuple[int, int], ...], exclude_args:Tuple[int, ...]) -> List[int]:
+ def _swizzle_args(cargs:tuple[tuple[int, int], ...], eargs:tuple[tuple[int, int], ...], exclude_args:tuple[int, ...]) -> list[int]:
  return [_expand_arg_to_idx(eargs, {**rpk, **{x:0 for x in exclude_args}} if exclude_args else rpk) for rpk in _choices_from_args(cargs)]

  def do_expand(root:UOp):
- expands = [x for x in root.src if x.op is Ops.EXPAND]
+ expands = [x for x in root.src if x.op is Ops.UNROLL]
  if len(expands) == 0: return None
  # NOTE: we 0 out the reduce axis for WMMA. in theory they should all be the same, but is this always correct?
  exclude_args = tuple(dedup(root.arg[-1] + tuple(y[0] for y in flatten(root.arg[-2])))) if root.op is Ops.WMMA else ()
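`_expand_arg_to_idx` and `_choices_from_args` implement mixed-radix indexing over the unroll axes: every combination of per-axis choices maps to a unique flat lane, with the last axis varying fastest. Copying the two helpers verbatim (sans type hints) shows the arithmetic:

```python
import itertools

def _expand_arg_to_idx(args, rpk):
    idx, mul = 0, 1
    for axis, m in args[::-1]:      # last axis has stride 1
        idx += rpk[axis] * mul
        mul *= m
    return idx

def _choices_from_args(args):
    # all per-axis choice dicts, e.g. {axis: value, ...}
    return [dict(x) for x in itertools.product(*[zip(itertools.repeat(axis), range(m)) for axis, m in args])]

# unroll axis 0 of size 2 and axis 1 of size 3 -> 6 flat lanes, in order
args = ((0, 2), (1, 3))
assert [_expand_arg_to_idx(args, rpk) for rpk in _choices_from_args(args)] == [0, 1, 2, 3, 4, 5]
```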
@@ -345,7 +351,7 @@ def do_expand(root:UOp):
  expand_sz = prod([x[1] for x in expand_args])
  new_srcs = []
  for i,src in enumerate(root.src):
- if src.op is Ops.EXPAND:
+ if src.op is Ops.UNROLL:
  if root.op is Ops.IF and i == 0:
  # IF means OR on first arg to IF
  new_srcs.append(functools.reduce(operator.__or__, [src.src[0].gep(i) for i in range(expand_sz)]))
@@ -358,9 +364,9 @@ def do_expand(root:UOp):
  if src.dtype.count > 1: lst = flatten([[i*src.dtype.count+j for j in range(src.dtype.count)] for i in lst])
  new_srcs.append(src.src[0].gep(tuple(lst)))
  else:
- # non-EXPAND input
+ # non-UNROLL input
  if root.op is Ops.IF:
- # for the first arg of IF, just pass them through ignoring EXPANDS
+ # for the first arg of IF, just pass them through ignoring UNROLLS
  new_srcs.append(src)
  elif src.dtype.count > 1:
  # put any input dtype > 1 grouped together
@@ -376,25 +382,25 @@ def do_expand(root:UOp):
  # is this right?
  new_arg = tuple(range(root.arg[0], new_srcs[0].dtype.count, new_srcs[0].dtype.count // expand_sz))
  nsrc = UOp(root.op, root.dtype.scalar().vec(root.dtype.count*expand_sz), tuple(new_srcs), new_arg)
- return UOp(Ops.EXPAND, root.dtype, (nsrc,), expand_args)
+ return UOp(Ops.UNROLL, root.dtype, (nsrc,), expand_args)

  def do_contract(con:UOp):
  ex = con.src[0]
- # CONTRACT without EXPAND repeats the element VECTORIZED
- if ex.op is not Ops.EXPAND: return UOp(Ops.VECTORIZE, con.dtype, con.src*con.dtype.count)
- # CONTRACT may remove several axes from EXPAND
+ # CONTRACT without UNROLL repeats the element VECTORIZED
+ if ex.op is not Ops.UNROLL: return UOp(Ops.VECTORIZE, con.dtype, con.src*con.dtype.count)
+ # CONTRACT may remove several axes from UNROLL
  assert con.dtype.count == prod([x[1] for x in con.arg]), "dtype is wrong"
  idxs = []
  for rpk in _choices_from_args(new_ex_args:=tuple(x for x in ex.arg if x not in con.arg)):
  idxs += [_expand_arg_to_idx(ex.arg, {**rpk, **lrpk}) for lrpk in _choices_from_args(con.arg)]
- return UOp(Ops.EXPAND, con.dtype, (ex.src[0].gep(tuple(idxs)),), new_ex_args)
+ return UOp(Ops.UNROLL, con.dtype, (ex.src[0].gep(tuple(idxs)),), new_ex_args)

  def no_vectorized_alu(alu):
  if alu.dtype.vcount == 1: return None
  alus = tuple(UOp(alu.op, alu.dtype.scalar(), tuple(s.gep(i) for s in alu.src), alu.arg) for i in range(alu.dtype.vcount))
  return UOp(Ops.VECTORIZE, alu.dtype, alus)

- def create_gate(root:UOp) -> Optional[UOp]:
+ def create_gate(root:UOp) -> UOp|None:
  @functools.lru_cache(None)
  def _gate_srcs(u:UOp, gate:UOp) -> UOp:
  if u.op is Ops.BARRIER: return u
@@ -407,22 +413,22 @@ def create_gate(root:UOp) -> Optional[UOp]:

  expander = PatternMatcher([
  # double expand
- (UPat(Ops.EXPAND, name="outer", src=(UPat(Ops.EXPAND, name="inner"),)),
- lambda outer, inner: UOp(Ops.EXPAND, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
+ (UPat(Ops.UNROLL, name="outer", src=(UPat(Ops.UNROLL, name="inner"),)),
+ lambda outer, inner: UOp(Ops.UNROLL, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
  # do expansion
  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.ASSIGN,
- Ops.VECTORIZE, Ops.IF), name="root", custom_early_reject=set([Ops.EXPAND])), do_expand),
+ Ops.VECTORIZE, Ops.IF), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
  (UPat(Ops.CONTRACT, name="con"), do_contract),
  # vectorize DEFINE_ACC
  (UPat(Ops.VECTORIZE, src=UPat(Ops.DEFINE_ACC, name="acc"), name="v"), lambda acc,v: acc.replace(dtype=v.dtype)),
  # BARRIERs aren't actually expanded
- (UPat(Ops.BARRIER, src=(UPat(Ops.EXPAND, name="ex"),)),
- lambda ex: UOp(Ops.EXPAND, dtypes.void, (UOp(Ops.BARRIER, dtypes.void, ex.src),)*len(ex.src), ex.arg)),
- # empty EXPAND is NOOP
- (UPat(Ops.EXPAND, src=(UPat.var('x'),), arg=()), lambda x: x),
- # EXPAND GEP (needed for WMMA, generalize this) -> vectorized ALU
- (UPat(Ops.EXPAND, name="ex", src=tuple(UPat.var('x').gep(i)+UPat.var('y').gep(i) for i in range(256 if AMX else 8))),
- lambda ex,x,y: UOp(Ops.EXPAND, ex.dtype, tuple((x+y).gep(i) for i in range(256 if AMX else 8)), ex.arg)),
+ (UPat(Ops.BARRIER, src=(UPat(Ops.UNROLL, name="ex"),)),
+ lambda ex: UOp(Ops.UNROLL, dtypes.void, (UOp(Ops.BARRIER, dtypes.void, ex.src),)*len(ex.src), ex.arg)),
+ # empty UNROLL is NOOP
+ (UPat(Ops.UNROLL, src=(UPat.var('x'),), arg=()), lambda x: x),
+ # UNROLL GEP (needed for WMMA, generalize this) -> vectorized ALU
+ (UPat(Ops.UNROLL, name="ex", src=tuple(UPat.var('x').gep(i)+UPat.var('y').gep(i) for i in range(256 if AMX else 8))),
+ lambda ex,x,y: UOp(Ops.UNROLL, ex.dtype, tuple((x+y).gep(i) for i in range(256 if AMX else 8)), ex.arg)),
  ])

  def no_vectorized_load_store(ls:UOp):
@@ -446,8 +452,8 @@ devectorize = PatternMatcher([
  (UPat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store),
  ])

- def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:Optional[UOp]=None) -> Optional[UOp]:
- if store_gate not in [gate.src[0] for gate in val.sparents if gate.op is Ops.IF]: return None
+ def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:UOp|None=None) -> UOp|None:
+ if store_gate not in [gate.src[0] for gate in val.toposort if gate.op is Ops.IF]: return None
  # remove the gate from the index
  return UOp.store(buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx), val)

@@ -468,7 +474,7 @@ migrate_indexing = PatternMatcher([
  (UPat(Ops.STORE, name="root"), create_gate),
  ])

- def move_mask(x:UOp, buf:UOp, idx:UOp, mask:UOp, cast:Optional[UOp]=None) -> UOp:
+ def move_mask(x:UOp, buf:UOp, idx:UOp, mask:UOp, cast:UOp|None=None) -> UOp:
  # this moves the mask from the indexing to the load/store op for rendering
  nidx = buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx)
  return UOp.load(nidx, x.const_like(0), mask, *x.src[1:], dtype=x.dtype) if x.op is Ops.LOAD else UOp.store(nidx, x.src[1], mask, *x.src[2:])
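`move_mask` pushes the INDEX mask onto the LOAD/STORE itself, so the renderer sees a gated access with an explicit alternative value. A plain-Python model of the load semantics this produces:

```python
# LOAD(buf.index(idx), alt, gate): the buffer is dereferenced only when the
# gate is true; otherwise the load yields the alternative (const_like(0) above)
def masked_load(buf, idx, gate, alt=0):
    return buf[idx] if gate else alt

buf = [1, 2, 3]
assert masked_load(buf, 1, True) == 2
assert masked_load(buf, 99, False) == 0  # out-of-bounds index is never touched
```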
@@ -481,8 +487,11 @@ pm_render = PatternMatcher([
  (UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
  (UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
  # move masks of loads/stores
- (UPat((Ops.LOAD, Ops.STORE), src=(UPat.any(masked_index:=UPat(Ops.INDEX, src=(UPat(name="buf"), UPat(name="idx"), UPat(name="mask"))),
+ (UPat((Ops.LOAD, Ops.STORE), src=(UPat.any(masked_index:=UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"), UPat.var("mask"))),
  masked_index.cast(None).named("cast")),), allow_any_len=True, name="x"), move_mask),
+ # gate any stores that aren't gated with ifs
+ (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
+ lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
  ])

  # *** uop graph ***
@@ -498,8 +507,9 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
  # expand
  sink = graph_rewrite(sink, sym+expander)

- # devectorize + load_store_indexing
- sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)+load_store_indexing)
+ # devectorize + load_store_indexing + mulacc_unrolled, mulacc_unrolled must be last because it can break loop_collapse
+ sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)+load_store_indexing+
+ mulacc_unrolled)

  # final rules for the renderer (without sym)
  sink = graph_rewrite(sink, symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)+pm_render+extra_matcher)
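`full_graph_rewrite` applies its PatternMatcher rule sets in ordered `graph_rewrite` passes (expand, then devectorize plus `mulacc_unrolled`, then the renderer-specific late rules), each run to a fixed point. A toy model of that fixpoint idea on plain expression tuples, heavily simplified relative to the real version, which works on UOp DAGs with a compiled matcher:

```python
# apply local rules bottom-up until nothing matches anymore
def rewrite(node, rules):
    if isinstance(node, tuple):
        node = tuple(rewrite(n, rules) for n in node)
    for rule in rules:
        if (new := rule(node)) is not None:
            return rewrite(new, rules)
    return node

# two algebraic rules: x*1 -> x and x+0 -> x, on ('op', lhs, rhs) tuples
rules = [
    lambda n: n[1] if isinstance(n, tuple) and n[0] == '*' and n[2] == 1 else None,
    lambda n: n[1] if isinstance(n, tuple) and n[0] == '+' and n[2] == 0 else None,
]
assert rewrite(('+', ('*', 'x', 1), 0), rules) == 'x'
```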
tinygrad/codegen/transcendental.py

@@ -1,5 +1,4 @@
  import math
- from typing import Tuple
  from tinygrad.dtype import dtypes, DType
  from tinygrad.helpers import polyN
  from tinygrad.ops import UOp
@@ -22,7 +21,7 @@ def shl(x:UOp, y:int) -> UOp: return x * (2**y)
  def rintk(d:UOp) -> UOp:
  """round d:float to int away from 0"""
  out_dtype = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype]
- return (d + d.lt(0.0).where(d.const_like(-0.5), d.const_like(0.5))).cast(out_dtype)
+ return (d + (d<0.0).where(d.const_like(-0.5), d.const_like(0.5))).cast(out_dtype)

  def pow2if(q:UOp, float_dtype:DType):
  """cast(2^q, float_dtype) where q is any integer in the range of [-126, 127]"""
@@ -49,7 +48,7 @@ def ldexp2k(d:UOp, e:UOp) -> UOp:
  assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype in (dtypes.int16, dtypes.int32, dtypes.int64)
  return (d * pow2if(shr(e, 1), d.dtype)) * pow2if(e - shr(e, 1), d.dtype)

- def frexp(v:UOp) -> Tuple[UOp, UOp]:
+ def frexp(v:UOp) -> tuple[UOp, UOp]:
  """frexp(v) -> (mantissa, exponent) assuming v != 0"""
  assert v.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
  # m1 = masks for mantissa, m2 = masks to normalize the mantissa.
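The `frexp` contract is the familiar one from the C and Python standard libraries: split `v` into a normalized mantissa and a power-of-two exponent with `v == m * 2**e`. This file builds the decomposition out of bit masks so it runs on-device; the stdlib version demonstrates the invariant:

```python
import math

v = 1234.5678
m, e = math.frexp(v)   # stdlib normalizes m into [0.5, 1)
assert 0.5 <= abs(m) < 1.0
assert m * 2**e == v   # exact: frexp only moves exponent bits around
```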
@@ -63,7 +62,7 @@ def frexp(v:UOp) -> Tuple[UOp, UOp]:
  return mantissa, exp

  # *** reduction algorithms for sine ***
- def payne_hanek_reduction(d:UOp) -> Tuple[UOp, UOp]:
+ def payne_hanek_reduction(d:UOp) -> tuple[UOp, UOp]:
  """
  Performs Payne-Hanek Reduction: computes the remainder of `d` modulo pi/2 for the values `d` where
  39800.0 <= d <= +Inf
@@ -110,9 +109,9 @@ def payne_hanek_reduction(d:UOp) -> Tuple[UOp, UOp]:
  r = (p.cast(intermediate_dtype) * (3.4061215800865545e-19)).cast(d.dtype)

  # if fraction >= 0.5, r -= pi/2, q += 1
- return f.lt(0.5).where(r, r - math.pi/2), f.lt(0.5).where(q, q + 1)
+ return (f<0.5).where(r, r - math.pi/2), (f<0.5).where(q, q + 1)

- def cody_waite_reduction(d:UOp) -> Tuple[UOp, UOp]:
+ def cody_waite_reduction(d:UOp) -> tuple[UOp, UOp]:
  """
  Performs Cody-Waite Reduction: computes the reminder of `d` modulo pi/2 for the values `d` where
  0 <= abs(d) <= 39800.0
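Both reductions compute `(r, q)` with `d ≈ q*(pi/2) + r` and `|r| <= pi/4`, so the sine polynomial only ever sees a small argument while `q` selects the quadrant. A naive double-precision sketch of that contract (which loses accuracy for huge `d`; avoiding that loss is exactly why Payne-Hanek uses extended-precision arithmetic):

```python
import math

d = 12345.678
q = round(d / (math.pi / 2))   # quadrant index
r = d - q * (math.pi / 2)      # remainder, |r| <= pi/4 plus float slop
assert abs(r) <= math.pi / 4 + 1e-9
assert math.isclose(math.sin(d), math.sin(q * (math.pi / 2) + r))
```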
@@ -177,14 +176,14 @@ def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp:
  # mask +-inf/nan as zero
  x = _lazy_map_numbers(d, d.const_like(0.0), d.const_like(0.0), d.const_like(0.0), d)
  # x_sign = sign(x)
- x_sign = x.ne(0).where(x.lt(0).where(x.const_like(-1), x.const_like(1)), x.const_like(0))
+ x_sign = x.ne(0).where((x<0).where(x.const_like(-1), x.const_like(1)), x.const_like(0))
  x_abs = x * x_sign
  r, q = (cody_waite_reduction if fast else payne_hanek_reduction)(x_abs)
  if fast: result = sin_poly_small(r, q)
  else:
  # Payne Hanek Reduction assumes abs(x) >= pi/4, so for smaller values, use cody_waite_reduction.
  r_small, q_small = cody_waite_reduction(x_abs)
- result = x_abs.lt(switch_over).where(sin_poly_small(r_small, q_small), sin_poly_large(r, q))
+ result = (x_abs<switch_over).where(sin_poly_small(r_small, q_small), sin_poly_large(r, q))
  # adjusts the sign for abs(x)
  result = result * x_sign
  # sin(Inf) = NaN, sin(-Inf) = NaN, sin(NaN) = NaN
@@ -210,9 +209,9 @@ def xexp2(d:UOp) -> UOp:
  u = ldexp2k(u, q) # u*2^q
  upper, lower = {dtypes.float64: (1024, -2000), dtypes.float32: (128, -150), dtypes.float16: (23, -22)}[d.dtype]
  # Replace x >= upper with +inf
- u = d.ge(upper).where(d.const_like(math.inf), u)
- # Replace x <= lower with zero.
- u = d.lt(lower).where(d.const_like(0.0), u)
+ u = (d >= upper).where(d.const_like(math.inf), u)
+ # Replace x < lower with zero.
+ u = (d<lower).where(d.const_like(0.0), u)
  # exp2(NaN) = NaN
  return d.ne(d).where(d.const_like(math.nan), u)
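The clamps in `xexp2` exist because the polynomial path cannot produce the IEEE special values on its own: beyond the dtype's exponent range the result has to be forced to +inf or 0.0 explicitly via `where()`. The float64 analogue in plain Python:

```python
# underflow quietly flushes to zero, but overflow raises in CPython rather
# than returning inf; either way the kernel must write the special value itself
assert 2.0 ** -2000 == 0.0
try:
    2.0 ** 1024
except OverflowError:
    print("overflow: xexp2 maps this case to +inf explicitly")
```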
@@ -225,7 +224,7 @@ def xlog2(d:UOp) -> UOp:
  # TODO: float16 denormal need float32 to achieve precision
  if d.dtype == dtypes.float16: return xlog2(d.cast(dtypes.float32)).cast(dtypes.float16)
  FLT_MIN = d.const_like(1e-6 if d.dtype == dtypes.float16 else 1e-4)
- is_denormal = d.lt(FLT_MIN)
+ is_denormal = d<FLT_MIN
  a = is_denormal.where(d * (2 ** 64), d)

  e = ilogb2k(a * (1.0 / 0.75)).cast(a.dtype)
@@ -246,7 +245,7 @@ def xlog2(d:UOp) -> UOp:
  # log2(Inf) = Inf
  r = d.ne(math.inf).where(r, r.const_like(math.inf))
  # log2(x) = NaN for x < 0
- r = d.lt(-0.0).where(r.const_like(math.nan), r)
+ r = (d<-0.0).where(r.const_like(math.nan), r)
  # log2(0) = -Inf, but we will compare using the value of y because 1e-200==0 is true.
  # log2_zero = the value of unmasked xlog2(0.0).
  log2_zero = {dtypes.float64: -1087, dtypes.float32: -191, dtypes.float16: -79}[d.dtype]
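For reference, the special-case ladder in `xlog2` mirrors IEEE log2 semantics. They can be checked with numpy (assumed available here), since plain `math.log2` raises ValueError for zero and negative inputs instead of returning the special values:

```python
import numpy as np

with np.errstate(divide="ignore", invalid="ignore"):
    assert np.log2(np.float32(np.inf)) == np.inf   # log2(Inf) = Inf
    assert np.isnan(np.log2(np.float32(-1.0)))     # log2(x<0) = NaN
    assert np.log2(np.float32(0.0)) == -np.inf     # log2(0) = -Inf
```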