tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. tinygrad/__init__.py +11 -6
  2. tinygrad/codegen/kernel.py +308 -175
  3. tinygrad/codegen/linearize.py +95 -0
  4. tinygrad/codegen/lowerer.py +143 -0
  5. tinygrad/codegen/transcendental.py +257 -0
  6. tinygrad/codegen/uopgraph.py +506 -0
  7. tinygrad/device.py +72 -171
  8. tinygrad/dtype.py +122 -47
  9. tinygrad/engine/jit.py +184 -87
  10. tinygrad/{lazy.py → engine/lazy.py} +74 -66
  11. tinygrad/engine/memory.py +51 -0
  12. tinygrad/engine/realize.py +86 -61
  13. tinygrad/engine/schedule.py +366 -317
  14. tinygrad/engine/search.py +58 -47
  15. tinygrad/function.py +59 -58
  16. tinygrad/helpers.py +120 -102
  17. tinygrad/multi.py +82 -78
  18. tinygrad/nn/__init__.py +116 -67
  19. tinygrad/nn/datasets.py +12 -5
  20. tinygrad/nn/optim.py +1 -1
  21. tinygrad/nn/state.py +91 -6
  22. tinygrad/ops.py +1126 -143
  23. tinygrad/renderer/__init__.py +47 -23
  24. tinygrad/renderer/cstyle.py +338 -265
  25. tinygrad/renderer/llvmir.py +125 -143
  26. tinygrad/renderer/ptx.py +225 -0
  27. tinygrad/runtime/autogen/adreno.py +17904 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
  29. tinygrad/runtime/autogen/cuda.py +6 -162
  30. tinygrad/runtime/autogen/io_uring.py +97 -63
  31. tinygrad/runtime/autogen/kfd.py +60 -47
  32. tinygrad/runtime/autogen/kgsl.py +1386 -0
  33. tinygrad/runtime/autogen/libc.py +5462 -0
  34. tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
  35. tinygrad/runtime/autogen/nvrtc.py +579 -0
  36. tinygrad/runtime/autogen/opencl.py +11 -11
  37. tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
  38. tinygrad/runtime/graph/clang.py +3 -3
  39. tinygrad/runtime/graph/cuda.py +11 -15
  40. tinygrad/runtime/graph/hcq.py +120 -107
  41. tinygrad/runtime/graph/metal.py +71 -43
  42. tinygrad/runtime/ops_amd.py +244 -323
  43. tinygrad/runtime/ops_clang.py +12 -5
  44. tinygrad/runtime/ops_cloud.py +220 -0
  45. tinygrad/runtime/ops_cuda.py +42 -99
  46. tinygrad/runtime/ops_disk.py +25 -26
  47. tinygrad/runtime/ops_dsp.py +181 -0
  48. tinygrad/runtime/ops_gpu.py +29 -16
  49. tinygrad/runtime/ops_hip.py +68 -0
  50. tinygrad/runtime/ops_llvm.py +15 -10
  51. tinygrad/runtime/ops_metal.py +147 -64
  52. tinygrad/runtime/ops_nv.py +356 -397
  53. tinygrad/runtime/ops_python.py +78 -79
  54. tinygrad/runtime/ops_qcom.py +405 -0
  55. tinygrad/runtime/support/__init__.py +0 -0
  56. tinygrad/runtime/support/compiler_cuda.py +77 -0
  57. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
  58. tinygrad/runtime/support/elf.py +38 -0
  59. tinygrad/runtime/support/hcq.py +539 -0
  60. tinygrad/shape/shapetracker.py +40 -50
  61. tinygrad/shape/view.py +102 -63
  62. tinygrad/tensor.py +1109 -365
  63. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
  64. tinygrad-0.10.0.dist-info/RECORD +77 -0
  65. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
  66. tinygrad/codegen/linearizer.py +0 -528
  67. tinygrad/codegen/uops.py +0 -451
  68. tinygrad/engine/graph.py +0 -100
  69. tinygrad/renderer/assembly.py +0 -269
  70. tinygrad/shape/symbolic.py +0 -327
  71. tinygrad-0.9.1.dist-info/RECORD +0 -63
  72. /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
  73. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
  74. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
tinygrad/renderer/__init__.py
@@ -1,9 +1,8 @@
-from typing import Optional, List, Tuple, Dict
+from typing import Optional, List, Tuple, Dict, Callable, Any
 import functools
-from dataclasses import dataclass
-from tinygrad.helpers import getenv, to_function_name
-from tinygrad.codegen.uops import UOpGraph
-from tinygrad.shape.symbolic import sym_infer, sint, Variable
+from dataclasses import dataclass, field
+from tinygrad.helpers import to_function_name, dedup, prod
+from tinygrad.ops import Ops, UOp, flops_mem, sym_infer, sint, Variable
 from tinygrad.dtype import DType

 @dataclass(frozen=True)
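For downstream code, the practical effect of this hunk is an import-path move: the symbolic helpers (Variable, sym_infer, sint) now come from tinygrad.ops, and UOpGraph is gone entirely (renderers take a plain List[UOp]; see the Renderer hunk below). A minimal migration sketch, assuming sym_infer keeps its (expression, var_vals) calling convention; the variable name and values here are illustrative only:

# tinygrad 0.9.1:
#   from tinygrad.shape.symbolic import Variable, sym_infer
#   from tinygrad.codegen.uops import UOpGraph
# tinygrad 0.10.0: the symbolic helpers live in tinygrad.ops, and Variable is a UOp.
from tinygrad.ops import UOp, Variable, sym_infer

v = Variable("i", 1, 10)         # a DEFINE_VAR UOp with bounds [1, 10]
assert isinstance(v, UOp)
print(sym_infer(v * 2, {v: 3}))  # evaluate the symbolic expression with v bound to 3 -> 6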
@@ -12,30 +11,56 @@ class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x
   dtype_in: DType # dtype for A and B
   dtype_out: DType # dtype for C and D
   threads: List[Tuple[int,int]] # list of (TC dim,amt) that construct the warp thread structure
-  thread_local_aliases: List[List[List[int]]] # a list of [threads_1, ..., threads_n, upcast_1(unrolled), upcast_2(upcast)] defining the alias (-1 is upcast, 1-n is warp threads) for each TC dim # noqa: E501
-  thread_local_sizes: List[List[int]] # in each thread, the number of elements stored in registers for each TC dim
+  reduce_axes: List[Tuple[int,int]] # list of (TC dim,amt) that constructs the shape of the reduce dim
+  @property
+  def early_upcast_axes(self) -> List[Tuple[int,int]]: # list of (TC dim,amt) that upcasts the threads remainders of dims [0,1]
+    return [(d,self.dims[d]//sz) for d,sz in [(dim,prod(sz for d,sz in self.threads if d==dim)) for dim in range(2)] if self.dims[d]>sz]
+  upcast_axes: Tuple[List[Tuple[int,int]], List[Tuple[int,int]], List[Tuple[int,int]]] # list of (TC dim,amt) that upcast A, B and C
+  st1_pattern: Optional[Tuple[Tuple[Tuple[int,int], ...], Tuple[Tuple[int,int], ...]]] = None # pattern to fix shapetracker for A
+  st2_pattern: Optional[Tuple[Tuple[Tuple[int,int], ...], Tuple[Tuple[int,int], ...]]] = None # pattern to fix shapetracker for B
+  expanded_shape: Optional[Tuple[int, ...]] = None
+  opts_seq: Tuple[str,str] = ("UP","LC") # upcast input, local the thread pattern
   def __str__(self): return "_".join(["WMMA"] + list(map(str, self.dims)) + [self.dtype_in.name, self.dtype_out.name])
-  def num_upcasts(self): return len(self.thread_local_aliases[0]) - len(self.threads)

-@dataclass(frozen=True)
+@dataclass
 class Program:
   name:str
   src:str
   dname:str
+  uops:Optional[List[UOp]]=None
+  mem_estimate:sint=0 # TODO: get this from the load/store uops once min/max are good
+
+  # filled in from uops (if we have uops)
   global_size:Optional[List[int]]=None
   local_size:Optional[List[int]]=None
-  uops:Optional[UOpGraph]=None
-  op_estimate:sint=0
-  mem_estimate:sint=0
+  vars:List[Variable]=field(default_factory=list)
+  globals:List[int]=field(default_factory=list)
+  outs:List[int]=field(default_factory=list)
+  _ran_post_init:bool=False # NOTE: this is needed if you call replace on the Program

-  @functools.cached_property
-  def vars(self) -> List[Variable]: return [] if self.uops is None else self.uops.vars()
+  def __post_init__(self):
+    if not self._ran_post_init and self.uops is not None:
+      # single pass through the uops
+      for u in self.uops:
+        if u.op is Ops.DEFINE_VAR: self.vars.append(u)
+        if u.op is Ops.DEFINE_GLOBAL: self.globals.append(u.arg)
+        if u.op is Ops.STORE: self.outs.extend([x.arg for x in u.src[0].sparents if x.op is Ops.DEFINE_GLOBAL])
+        if u.op is Ops.SPECIAL:
+          # NOTE: you have to set local_size and global_size to the base [1,1,1] outside this
+          if u.arg[0][0] == 'i': self.local_size = None
+          special_size = self.local_size if u.arg[0][0] == 'l' else self.global_size
+          assert special_size is not None
+          special_size[int(u.arg[0][-1])] = u.arg[1]
+      self.vars = sorted(self.vars, key=lambda v: v.arg)
+      self.outs = sorted(dedup(self.outs))
+      self._ran_post_init = True

+  @property
+  def op_estimate(self) -> sint: return self._ops_lds[0]
+  @property
+  def lds_estimate(self) -> sint: return self._ops_lds[1]
   @functools.cached_property
-  def globals(self) -> List[Tuple[int, bool]]: return [] if self.uops is None else self.uops.globals()
-
-  @functools.cached_property
-  def outcount(self) -> int: return sum(x[1] for x in self.globals)
+  def _ops_lds(self) -> Tuple[sint, sint]: return (0,0) if self.uops is None else flops_mem(self.uops, ignore_indexing=True)

   @functools.cached_property
   def function_name(self) -> str: return to_function_name(self.name)
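Program is now a mutable dataclass: vars, globals and outs are filled by a single pass over the uop list in __post_init__, and op_estimate/lds_estimate are derived lazily via flops_mem instead of being stored fields. A minimal sketch using only the fields visible in this hunk (the name, kernel source and device strings are placeholders); without uops the post-init pass is skipped and the estimates fall back to zero:

from tinygrad.renderer import Program

# No uops supplied, so __post_init__ does nothing and _ops_lds returns (0, 0).
prog = Program(name="copy", src="/* kernel source elided */", dname="CLANG")
print(prog.function_name)                   # "copy"
print(prog.op_estimate, prog.lds_estimate)  # 0 0
print(prog.vars, prog.globals, prog.outs)   # [] [] []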
@@ -57,9 +82,8 @@ class Renderer:
   local_max: Optional[Tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: UOps.SPECIAL int32 indexes right now
   shared_max: int = 32768
   tensor_cores: List[TensorCore] = []
-  @functools.cached_property
-  def tc_opt(self): return getenv("TC_OPT")
-  @functools.cached_property
-  def tc(self): return getenv("TC", 1)
+  extra_matcher: Any = None
+  code_for_op: Dict[Ops, Callable] = {}

-  def render(self, name:str, uops:UOpGraph) -> str: raise NotImplementedError("needs a renderer")
+  def __reduce__(self): return self.__class__, ()
+  def render(self, name:str, uops:List[UOp]) -> str: raise NotImplementedError("needs a renderer")
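The Renderer base class gains extra_matcher and code_for_op, becomes picklable via __reduce__, and its render method now takes the linearized uop list directly instead of a UOpGraph. A minimal sketch of a custom renderer against the new signature; the class name and the string it emits are hypothetical, only the method signature comes from this hunk:

from typing import List
from tinygrad.ops import UOp
from tinygrad.renderer import Renderer

# Stub renderer: emits a comment instead of real kernel source, just to show the
# 0.10.0 contract (a list of UOps in, a source string out).
class DebugRenderer(Renderer):
  def render(self, name:str, uops:List[UOp]) -> str:
    return f"// {name}: {len(uops)} uops"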