tinygrad 0.9.0__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (57)
  1. tinygrad/codegen/__init__.py +0 -0
  2. tinygrad/codegen/kernel.py +78 -90
  3. tinygrad/codegen/linearizer.py +237 -169
  4. tinygrad/codegen/uops.py +278 -242
  5. tinygrad/device.py +147 -10
  6. tinygrad/dtype.py +7 -7
  7. tinygrad/engine/graph.py +16 -16
  8. tinygrad/engine/jit.py +39 -36
  9. tinygrad/engine/realize.py +6 -5
  10. tinygrad/engine/schedule.py +15 -7
  11. tinygrad/engine/search.py +6 -3
  12. tinygrad/function.py +17 -23
  13. tinygrad/helpers.py +77 -8
  14. tinygrad/lazy.py +26 -26
  15. tinygrad/multi.py +13 -9
  16. tinygrad/nn/__init__.py +1 -1
  17. tinygrad/nn/datasets.py +2 -1
  18. tinygrad/nn/state.py +3 -4
  19. tinygrad/ops.py +49 -16
  20. tinygrad/renderer/__init__.py +8 -4
  21. tinygrad/renderer/assembly.py +93 -100
  22. tinygrad/renderer/cstyle.py +47 -42
  23. tinygrad/renderer/llvmir.py +30 -30
  24. tinygrad/runtime/__init__.py +0 -0
  25. tinygrad/runtime/autogen/amd_gpu.py +11504 -1
  26. tinygrad/runtime/autogen/comgr.py +36 -10
  27. tinygrad/runtime/autogen/hsa.py +146 -14
  28. tinygrad/runtime/autogen/io_uring.py +1486 -0
  29. tinygrad/runtime/autogen/nv_gpu.py +269 -0
  30. tinygrad/runtime/driver/__init__.py +0 -0
  31. tinygrad/runtime/driver/hip_comgr.py +20 -11
  32. tinygrad/runtime/graph/__init__.py +0 -0
  33. tinygrad/runtime/graph/clang.py +3 -2
  34. tinygrad/runtime/graph/cuda.py +2 -2
  35. tinygrad/runtime/graph/hcq.py +122 -78
  36. tinygrad/runtime/ops_amd.py +302 -316
  37. tinygrad/runtime/ops_cuda.py +3 -3
  38. tinygrad/runtime/ops_disk.py +70 -5
  39. tinygrad/runtime/ops_gpu.py +2 -2
  40. tinygrad/runtime/ops_metal.py +5 -6
  41. tinygrad/runtime/ops_npy.py +1 -1
  42. tinygrad/runtime/ops_nv.py +161 -166
  43. tinygrad/runtime/ops_python.py +20 -16
  44. tinygrad/shape/__init__.py +0 -0
  45. tinygrad/shape/shapetracker.py +5 -2
  46. tinygrad/shape/symbolic.py +1 -3
  47. tinygrad/shape/view.py +34 -19
  48. tinygrad/tensor.py +219 -135
  49. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +14 -6
  50. tinygrad-0.9.1.dist-info/RECORD +63 -0
  51. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
  52. tinygrad/runtime/driver/hsa.py +0 -143
  53. tinygrad/runtime/graph/hsa.py +0 -171
  54. tinygrad/runtime/ops_hsa.py +0 -278
  55. tinygrad-0.9.0.dist-info/RECORD +0 -60
  56. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +0 -0
  57. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
tinygrad/renderer/llvmir.py
@@ -1,9 +1,8 @@
 from typing import Final, Dict, Callable, Any, List, Optional
 from llvmlite import ir
-from tinygrad.codegen.linearizer import UOps, UOp
 from tinygrad.dtype import DType, PtrDType, dtypes
 from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
-from tinygrad.codegen.uops import UOpGraph
+from tinygrad.codegen.uops import UOps, UOp, UOpGraph
 from tinygrad.renderer import Renderer
 
 MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf
@@ -15,14 +14,14 @@ code_for_op: Final[Dict[Op, Callable]] = {
     (builder.not_(x) if dtype == dtypes.bool else builder.fneg(x, flags=MFLAGS)),
   UnaryOps.EXP2: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.exp2', [x.type]), [x], fastmath=MFLAGS),
   UnaryOps.LOG2: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.log2', [x.type]), [x], fastmath=MFLAGS),
+  UnaryOps.RECIP: lambda builder, x, dtype: builder.fdiv(const(1, dtype), x, flags=MFLAGS),
   UnaryOps.SIN: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.sin', [x.type]), [x], fastmath=MFLAGS),
   UnaryOps.SQRT: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.sqrt', [x.type]), [x], fastmath=MFLAGS),
   BinaryOps.ADD: lambda builder, x, y, dtype: builder.or_(x, y) if dtype == dtypes.bool else builder.add(x, y) if dtypes.is_int(dtype) else builder.fadd(x, y, flags=MFLAGS), # noqa: E501
-  BinaryOps.SUB: lambda builder, x, y, dtype: builder.sub(x, y) if dtypes.is_int(dtype) else builder.fsub(x, y, flags=MFLAGS),
   BinaryOps.MUL: lambda builder, x, y, dtype: builder.mul(x, y) if is_bool_or_unsigned(dtype) or dtypes.is_int(dtype) else builder.fmul(x, y, flags=MFLAGS), # noqa: E501
-  BinaryOps.DIV: lambda builder, x, y, dtype: builder.udiv(x, y) if is_bool_or_unsigned(dtype) else builder.sdiv(x, y) if dtypes.is_int(dtype) else builder.fdiv(x, y, flags=MFLAGS), # noqa: E501
+  BinaryOps.IDIV: lambda builder, x, y, dtype: builder.udiv(x, y) if is_bool_or_unsigned(dtype) else builder.sdiv(x, y),
   BinaryOps.CMPLT: lambda builder, x, y, dtype: builder.icmp_unsigned("<", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("<", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("<", x, y, flags=MFLAGS), # noqa: E501
-  BinaryOps.CMPEQ: lambda builder, x, y, dtype: builder.icmp_unsigned("==", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("==", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("==", x, y, flags=MFLAGS), # noqa: E501
+  BinaryOps.CMPNE: lambda builder, x, y, dtype: builder.icmp_unsigned("!=", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("!=", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("!=", x, y, flags=MFLAGS), # noqa: E501
   BinaryOps.MAX: lambda builder, x, y, dtype: builder.select(builder.icmp_unsigned(">", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed(">", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered(">", x, y, flags=MFLAGS), x, y), # noqa: E501
   BinaryOps.MOD: lambda builder, x, y, dtype: builder.urem(x, y) if is_bool_or_unsigned(dtype) else builder.srem(x, y) if dtypes.is_int(dtype) else builder.frem(x, y), # noqa: E501
   BinaryOps.XOR: lambda builder, x, y, dtype: builder.xor(x, y),
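
The only functional additions to the op table are UnaryOps.RECIP (1/x as a fast-math fdiv) and the integer-only BinaryOps.IDIV; SUB, the float DIV and CMPEQ entries are gone. A minimal, hedged llvmlite sketch of the IR the new RECIP entry emits (standalone; the function name `recip` is illustrative, not part of tinygrad):

from llvmlite import ir

MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc')  # same fast-math flags as llvmir.py

module = ir.Module(name="recip_sketch")
func = ir.Function(module, ir.FunctionType(ir.FloatType(), [ir.FloatType()]), name="recip")
builder = ir.IRBuilder(func.append_basic_block("entry"))
one = ir.Constant(ir.FloatType(), 1.0)                       # plays the role of const(1, dtype)
builder.ret(builder.fdiv(one, func.args[0], flags=MFLAGS))   # RECIP lowers to 1/x with MFLAGS
print(module)                                                # inspect the textual LLVM IR
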
@@ -68,16 +67,17 @@ def const(args, dtype): return ir.Constant(dtype_to_llvm_dtype[dtype], args)
 
 class LLVMRenderer(Renderer):
   device = "LLVM"
-  supports_float4=False
-  has_local=False
-  has_shared=False
+  supports_float4 = False
+  has_local = False
+  has_shared = False
+  global_max = None
 
   def render(self, name:str, uops:UOpGraph) -> str:
     # all llvm stuff goes into a module
     module = ir.Module(name=__file__)
 
     # extract global buffers (NOTE: this isn't right if DEFINE_GLOBAL is out of order)
-    buf_to_dtype = {u.arg:u.dtype for u in uops if u.uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}}
+    buf_to_dtype = {u.arg:u.dtype for u in uops if u.op in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}}
     buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
 
     # create llvm function
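
Beyond the whitespace cleanup, this hunk adds an explicit `global_max = None` to the LLVM renderer's capability flags. A hedged sketch of how a hypothetical out-of-tree renderer would declare the same flags against 0.9.1 (`MyRenderer` and its device string are illustrative; the comments are one reading of the flags, not tinygrad documentation):

from tinygrad.renderer import Renderer

class MyRenderer(Renderer):
  device = "MY"               # illustrative device name
  supports_float4 = False     # no vectorized float4 loads/stores
  has_local = False           # no local (workgroup) dimension
  has_shared = False          # no shared/local memory
  global_max = None           # presumably: no fixed launch-grid size limit for a CPU-style backend
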
@@ -100,21 +100,21 @@ class LLVMRenderer(Renderer):
       if not isinstance(dtype, PtrDType) and dtype == dtypes.int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
 
     for u in uops:
-      uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
+      uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
       if uop is UOps.STORE:
-        element = cast(bb, lvars[vin[2]], vin[2].dtype, vin[0].dtype)
-        if len(vin) > 3:
-          with bb[-1].if_then(lvars[vin[3]]):
-            bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
+        element = cast(bb, lvars[src[2]], src[2].dtype, src[0].dtype)
+        if len(src) > 3:
+          with bb[-1].if_then(lvars[src[3]]):
+            bb[-1].store(element, bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
         else:
-          bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
+          bb[-1].store(element, bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
       elif uop is UOps.ENDRANGE:
         loop_entry_bb, phis = loop_blocks.pop()
-        idx_p1 = bb[-1].add(lvars[vin[0]], ir.Constant(ir.IntType(32), 1))
-        lvars[vin[0]].add_incoming(idx_p1, bb[-1].block)
+        idx_p1 = bb[-1].add(lvars[src[0]], ir.Constant(ir.IntType(32), 1))
+        lvars[src[0]].add_incoming(idx_p1, bb[-1].block)
         for n,phi in phis: phi.add_incoming(lvars[n], bb[-1].block)
         bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
-        bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].vin[1]]), loop_entry_bb, bb[-1].block)
+        bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[src[0].src[1]]), loop_entry_bb, bb[-1].block)
       else:
         assert dtype is not None, f"None dtype for uop {uop}"
         if uop is UOps.RANGE:
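
Most of this hunk and the next is the mechanical rename that runs through 0.9.1: on UOp, `u.uop` becomes `u.op` and `u.vin` becomes `u.src`, while `u.dtype` and `u.arg` are unchanged. A hedged sketch of how code that walks a UOpGraph reads after the rename (`dump_uops` is an illustrative helper, not a tinygrad API):

from tinygrad.codegen.uops import UOps, UOpGraph  # UOps/UOp now live here (see the import hunk above)

def dump_uops(uops: UOpGraph) -> None:
  for u in uops:
    # 0.9.0 spelled these fields u.uop and u.vin
    print(u.op, u.dtype, [s.op for s in u.src], u.arg)
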
@@ -129,28 +129,28 @@ class LLVMRenderer(Renderer):
             phis.append((rp, lvars[rp]))
 
           lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}")
-          lvars[u].add_incoming(lvars[vin[0]], bb[-2].block)
+          lvars[u].add_incoming(lvars[src[0]], bb[-2].block)
           loop_blocks.append((bb[-1].block, phis))
         elif uop is UOps.DEFINE_ACC:
-          lvars[u] = const(args[0], dtype)
+          lvars[u] = const(src[0].arg, dtype)
           reduce_phis.append(u)
         elif uop is UOps.LOAD:
-          if len(vin) > 2:
-            aug_idx = bb[-1].select(lvars[vin[2]], lvars[vin[1]], ir.Constant(ir.IntType(32), 0))
-            val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [aug_idx], inbounds=True))
-            val = bb[-1].select(lvars[vin[2]], val, lvars[vin[3]])
+          if len(src) > 2:
+            aug_idx = bb[-1].select(lvars[src[2]], lvars[src[1]], ir.Constant(ir.IntType(32), 0))
+            val = bb[-1].load(bb[-1].gep(lvars[src[0]], [aug_idx], inbounds=True))
+            val = bb[-1].select(lvars[src[2]], val, lvars[src[3]])
           else:
-            val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
+            val = bb[-1].load(bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
           lvars[u] = val
         elif uop is UOps.PHI:
-          lvars[u] = lvars[vin[1]]
+          lvars[u] = lvars[src[1]]
           # PHI UOps can link to other PHI Uops, backtrace this to DEFINE_ACC
-          backward = vin[0]
-          while backward.uop is UOps.PHI: backward = backward.vin[0]
+          backward = src[0]
+          while backward.op is UOps.PHI: backward = backward.src[0]
           lvars[backward] = lvars[u]
         elif uop is UOps.ALU:
-          lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else vin[0].dtype)
-        elif uop in {UOps.CAST, UOps.BITCAST}: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype, bitcast=uop is UOps.BITCAST)
+          lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in src], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPNE) else src[0].dtype)
+        elif uop in {UOps.CAST, UOps.BITCAST}: lvars[u] = cast(bb, lvars[src[0]], src[0].dtype, dtype, bitcast=uop is UOps.BITCAST)
         elif uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
         elif uop is UOps.SPECIAL: lvars[u] = lvars[args.expr]
         elif uop is UOps.CONST: lvars[u] = const(args, dtype)
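
A hedged end-to-end check of the changed ops on this backend (assumes tinygrad 0.9.1 with llvmlite installed, and that the usual `LLVM=1` environment switch selects the LLVM device as it does for other tinygrad backends):

import os
os.environ["LLVM"] = "1"        # select the LLVM backend before importing tinygrad
from tinygrad import Tensor

x = Tensor([1.0, 2.0, 4.0])
print(x.reciprocal().numpy())   # should lower through the new UnaryOps.RECIP: [1.0, 0.5, 0.25]
print((x != 2.0).numpy())       # not-equal now goes through CMPNE (CMPEQ was removed in this diff)
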