tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in their public registries.
Files changed (131)
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
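Several modules were moved or renamed in 0.11.0 (for example tinygrad/ops.py → tinygrad/uop/ops.py and tinygrad/engine/search.py → tinygrad/codegen/opt/search.py), so downstream code importing the old paths must be updated. A minimal sketch of the import change, assuming tinygrad 0.11.0 is installed; the uop.ops import is confirmed by the engine/jit.py diff below, the search path is inferred from the rename entry above:

# 0.10.2 import paths (no longer present in 0.11.0):
#   from tinygrad.ops import UOp, Ops, sym_infer
#   import tinygrad.engine.search
# 0.11.0 import paths:
from tinygrad.uop.ops import UOp, Ops, sym_infer
import tinygrad.codegen.opt.search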
tinygrad/dtype.py CHANGED
@@ -1,10 +1,11 @@
  from __future__ import annotations
- from typing import Final, Optional, ClassVar, Union, Callable, Literal
+ from typing import Final, ClassVar, Callable, Literal
  import math, struct, ctypes, functools
  from dataclasses import dataclass, fields
  from tinygrad.helpers import getenv, prod
+ from enum import Enum, auto
 
- ConstType = Union[float, int, bool]
+ ConstType = float|int|bool
 
  FmtStr = Literal['?', 'b', 'B', 'h', 'H', 'i', 'I', 'q', 'Q', 'e', 'f', 'd']
 
@@ -16,16 +17,18 @@ class DTypeMetaClass(type):
  DTypeMetaClass.dcache[args] = ret = super().__call__(*args)
  return ret
 
+ class AddrSpace(Enum): GLOBAL = auto(); LOCAL = auto(); REG = auto() # noqa: E702
+
  @dataclass(frozen=True, eq=False)
  class DType(metaclass=DTypeMetaClass):
  priority: int # this determines when things get upcasted
  itemsize: int
  name: str
- fmt: Optional[FmtStr]
+ fmt: FmtStr|None
  count: int
- _scalar: Optional[DType]
+ _scalar: DType|None
  @staticmethod
- def new(priority:int, itemsize:int, name:str, fmt:Optional[FmtStr]): return DType(priority, itemsize, name, fmt, 1, None)
+ def new(priority:int, itemsize:int, name:str, fmt:FmtStr|None): return DType(priority, itemsize, name, fmt, 1, None)
  def __reduce__(self): return type(self), tuple(getattr(self, f.name) for f in fields(self))
  def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self.scalar().name]}"+(f".vec({self.count})" if self.count > 1 else "")
  def __lt__(self, o:DType): return (self.priority, self.itemsize, self.name, self.fmt, self.count) < (o.priority, o.itemsize, o.name, o.fmt, o.count)
@@ -33,51 +36,62 @@ class DType(metaclass=DTypeMetaClass):
  def base(self): return self
  @property
  def vcount(self): return self.count
- @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
+ @functools.cache # pylint: disable=method-cache-max-size-none
  def vec(self, sz:int) -> DType:
  assert self.count == 1, f"can't vectorize {self} with size {sz}"
  if sz == 1 or self == dtypes.void: return self # void doesn't vectorize, and sz=1 is scalar
  return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz, self)
- def ptr(self, size=-1, local=False) -> PtrDType:
- return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, None, self, local, 1, size)
+ def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL) -> PtrDType:
+ return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, None, self, addrspace, 1, size)
  def scalar(self) -> DType: return self._scalar if self._scalar is not None else self
+ def nbytes(self): raise RuntimeError("only ptr types have nbytes")
+ @property
+ def min(self): return dtypes.min(self)
+ @property
+ def max(self): return dtypes.max(self)
 
  @dataclass(frozen=True, eq=False)
  class PtrDType(DType):
  _base: DType
- local: bool
+ addrspace: AddrSpace
  v: int
  size: int = -1 # -1 is unlimited size
  @property
  def base(self): return self._base
- @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
+ @functools.cache # pylint: disable=method-cache-max-size-none
  def vec(self, sz:int) -> DType:
  assert self.v == 1, f"can't vectorize ptr {self} with size {sz}"
  if sz == 1: return self # sz=1 is a scalar
- return type(self)(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.local, sz, self.size)
- def ptr(self, size=-1, local=False): raise RuntimeError("can't make a pointer from a pointer")
+ if isinstance(self, ImageDType):
+ return ImageDType(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.addrspace, sz, self.size, self.shape)
+ return type(self)(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.addrspace, sz, self.size)
+ def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL): raise RuntimeError("can't make a pointer from a pointer")
+ def nbytes(self) -> int:
+ if self.size == -1: return 0 # TODO: this should be an exception
+ return self.size*self.itemsize
  @property
  def vcount(self): return self.v
  def __repr__(self):
- return f"{self.base.__repr__()}.ptr({self.size}{', local=True' if self.local else ''})" + (f'.vec({self.v})' if self.v != 1 else '')
+ return f"{self.base.__repr__()}.ptr({self.size}{', '+str(self.addrspace) if self.addrspace != AddrSpace.GLOBAL else ''})" + \
+ (f'.vec({self.v})' if self.v != 1 else '')
 
  @dataclass(frozen=True, eq=False)
  class ImageDType(PtrDType):
  shape: tuple[int, ...] = () # shape of the Image
- def ptr(self, size=-1, local=False) -> PtrDType:
- assert not local, "images can't be local"
+ def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL) -> PtrDType:
+ assert addrspace == AddrSpace.GLOBAL, "images can't be local"
  return self
  def __repr__(self): return f"dtypes.{self.name}({self.shape})" + (f'.vec({self.v})' if self.v != 1 else '')
 
  class dtypes:
  @staticmethod
- @functools.lru_cache(None)
+ @functools.cache
  def is_float(x: DType) -> bool: return x.scalar() in dtypes.floats or isinstance(x, ImageDType)
  @staticmethod # static methods on top, or bool in the type info will refer to dtypes.bool
- @functools.lru_cache(None)
+ @functools.cache
  def is_int(x: DType) -> bool: return x.scalar() in dtypes.ints
  @staticmethod
- @functools.lru_cache(None)
+ @functools.cache
  def is_unsigned(x: DType) -> bool: return x.scalar() in dtypes.uints
  @staticmethod
  def is_bool(x: DType) -> bool: return x.scalar() == dtypes.bool
@@ -97,12 +111,12 @@ class dtypes:
  # TODO: should truncate here
  return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
  @staticmethod
- @functools.lru_cache(None)
+ @functools.cache
  def min(dtype:DType):
  if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.itemsize*8-1)
  return -float("inf") if dtypes.is_float(dtype) else False
  @staticmethod
- @functools.lru_cache(None)
+ @functools.cache
  def max(dtype:DType):
  if dtypes.is_int(dtype): return 2**(dtype.itemsize*8)-1+dtypes.min(dtype)
  return float("inf") if dtypes.is_float(dtype) else True
@@ -110,7 +124,8 @@ class dtypes:
  def finfo(dtype:DType) -> tuple[int, int]:
  """(exponent, mantissa)"""
  if not dtypes.is_float(dtype): raise ValueError(f"{dtype} is not a floating point type")
- return {dtypes.float16: (5, 10), dtypes.bfloat16: (8, 7), dtypes.float32: (8, 23), dtypes.float64: (11, 52)}[dtype]
+ return {dtypes.float16: (5, 10), dtypes.bfloat16: (8, 7), dtypes.float32: (8, 23), dtypes.float64: (11, 52),
+ dtypes.fp8e5m2: (5, 2), dtypes.fp8e4m3: (4, 3)}[dtype]
  @staticmethod
  def fields() -> dict[str, DType]: return DTYPES_DICT
  void: Final[DType] = DType.new(-1, 0, "void", None)
@@ -123,11 +138,13 @@ class dtypes:
  uint32: Final[DType] = DType.new(6, 4, "unsigned int", 'I')
  int64: Final[DType] = DType.new(7, 8, "long", 'q')
  uint64: Final[DType] = DType.new(8, 8, "unsigned long", 'Q')
- float16: Final[DType] = DType.new(9, 2, "half", 'e')
+ fp8e4m3: Final[DType] = DType.new(9, 1, "float8_e4m3", None)
+ fp8e5m2: Final[DType] = DType.new(10, 1, "float8_e5m2", None)
+ float16: Final[DType] = DType.new(11, 2, "half", 'e')
  # bfloat16 has higher priority than float16, so least_upper_dtype(dtypes.int64, dtypes.uint64) = dtypes.float16
- bfloat16: Final[DType] = DType.new(10, 2, "__bf16", None)
- float32: Final[DType] = DType.new(11, 4, "float", 'f')
- float64: Final[DType] = DType.new(12, 8, "double", 'd')
+ bfloat16: Final[DType] = DType.new(12, 2, "__bf16", None)
+ float32: Final[DType] = DType.new(13, 4, "float", 'f')
+ float64: Final[DType] = DType.new(14, 8, "double", 'd')
 
  # dtype aliases
  half = float16; float = float32; double = float64 # noqa: E702
@@ -136,48 +153,66 @@ class dtypes:
 
  # NOTE: these are image dtypes
  @staticmethod
- def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, None, dtypes.float32, False, 1, prod(shp), shp)
+ def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp)
  @staticmethod
- def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, None, dtypes.float32, False, 1, prod(shp), shp)
+ def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp)
 
  default_float: ClassVar[DType] = float32
  default_int: ClassVar[DType] = int32
 
- floats = (float16, bfloat16, float32, float64)
+ fp8s = (fp8e4m3, fp8e5m2)
+ floats = fp8s + (float16, bfloat16, float32, float64)
  uints = (uint8, uint16, uint32, uint64)
  sints = (int8, int16, int32, int64)
  ints = uints + sints
+ all = floats + ints + (bool,)
 
  if (env_default_float := getenv("DEFAULT_FLOAT", "")):
  dtypes.default_float = getattr(dtypes, env_default_float.lower())
  assert dtypes.is_float(dtypes.default_float), f"{env_default_float} is not a float dtype"
 
- DTypeLike = Union[str, DType]
- def to_dtype(dtype:DTypeLike) -> DType: return dtype if isinstance(dtype, DType) else getattr(dtypes, dtype)
+ DTypeLike = str|DType
+ def to_dtype(dtype:DTypeLike) -> DType: return dtype if isinstance(dtype, DType) else getattr(dtypes, dtype.lower())
 
  # https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
  # we don't support weak type and complex type
  promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64],
  dtypes.int64: [dtypes.float16, dtypes.bfloat16], dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32],
  dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.float16, dtypes.bfloat16],
+ dtypes.fp8e5m2: [dtypes.float16, dtypes.bfloat16], dtypes.fp8e4m3: [dtypes.float16, dtypes.bfloat16],
  dtypes.float16: [dtypes.float32], dtypes.bfloat16: [dtypes.float32], dtypes.float32: [dtypes.float64], }
 
- @functools.lru_cache(None)
+ @functools.cache
  def _get_recursive_parents(dtype:DType) -> set[DType]:
  return set.union(*[_get_recursive_parents(d) for d in promo_lattice[dtype]], {dtype}) if dtype != dtypes.float64 else {dtypes.float64}
- @functools.lru_cache(None)
+ @functools.cache
  def least_upper_dtype(*ds:DType) -> DType:
  return min(set.intersection(*[_get_recursive_parents(d) for d in ds])) if not (images:=[d for d in ds if isinstance(d, ImageDType)]) else images[0]
- def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.float32)
+ def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.default_float)
 
  DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if isinstance(v, DType) and not k.startswith(("default", "void"))}
  INVERSE_DTYPES_DICT = {**{v.name:k for k,v in DTYPES_DICT.items()}, "void": "void"}
 
+ @functools.cache
+ def can_safe_cast(dt0:DType, dt1:DType) -> bool:
+ # return if dt1 preserves value of dt0
+ # https://numpy.org/doc/stable/reference/generated/numpy.can_cast.html
+ if dt0 == dt1 or dt0 == dtypes.bool: return True
+ match dt1:
+ case dtypes.double: return dt0 in (dtypes.float, dtypes.half, dtypes.bfloat16)
+ case dtypes.float: return dt0 in (dtypes.half, dtypes.bfloat16)
+ case dtypes.uint64: return dt0 in (dtypes.uint32, dtypes.uint16, dtypes.uint8)
+ case dtypes.uint32: return dt0 in (dtypes.uint16, dtypes.uint8)
+ case dtypes.int64: return dt0 in (dtypes.uint32, dtypes.uint16, dtypes.uint8, dtypes.int32, dtypes.int16, dtypes.int8)
+ case dtypes.int32: return dt0 in (dtypes.uint16, dtypes.uint8, dtypes.int16, dtypes.int8)
+ case dtypes.int16: return dt0 in (dtypes.uint8, dtypes.int8)
+ case _: return False
+
  def sum_acc_dtype(dt:DType):
  # default acc dtype for sum
  if dtypes.is_unsigned(dt): return least_upper_dtype(dt, dtypes.uint)
  if dtypes.is_int(dt) or dt == dtypes.bool: return least_upper_dtype(dt, dtypes.int)
- return least_upper_dtype(dt, dtypes.float)
+ return least_upper_dtype(dt, to_dtype(getenv("SUM_DTYPE", "float32")))
 
  def truncate_fp16(x):
  try: return struct.unpack("@e", struct.pack("@e", float(x)))[0]
@@ -185,15 +220,97 @@ def truncate_fp16(x):
 
  def truncate_bf16(x):
  max_bf16 = struct.unpack('f', struct.pack('I', 0x7f7f0000))[0]
- if x > max_bf16 or x < -max_bf16: return math.copysign(math.inf, x)
+ if abs(x) > max_bf16: return math.copysign(math.inf, x)
  f32_int = struct.unpack('I', struct.pack('f', x))[0]
  bf = struct.unpack('f', struct.pack('I', f32_int & 0xFFFF0000))[0]
  return bf
 
+ # fp8-float conversions based on https://gitlab.com/nvidia/headers/cuda-individual/cudart/-/blob/main/cuda_fp8.hpp
+ def float_to_fp8(x: float, dtype: DType) -> int:
+ assert dtype in dtypes.fp8s, "Only for fp8s"
+ config = {
+ dtypes.fp8e4m3: {"EXP_BIAS": 7, "SIGNIFICAND_BITS": 4, "MANTISSA_MASK": 0x7, "MINDENORM_O2": 0x3F50000000000000,
+ "OVERFLOW_THRESHOLD": 0x407D000000000000, "MAXNORM": 0x7E, "MINNORM": 0x3F90000000000000, "INF_VALUE": 0x7F},
+ dtypes.fp8e5m2: {"EXP_BIAS": 15, "SIGNIFICAND_BITS": 3, "MANTISSA_MASK": 0x3, "MINDENORM_O2": 0x3EE0000000000000,
+ "OVERFLOW_THRESHOLD": 0x40EE000000000000 - 1, "MAXNORM": 0x7B, "MINNORM": 0x3F10000000000000, "INF_VALUE": 0x7E}
+ }[dtype]
+ xbits, = struct.unpack('Q', struct.pack('d', x))
+ FP8_DP_HALF_ULP = 1 << (53 - config["SIGNIFICAND_BITS"] - 1)
+ sign = ((xbits >> 63) & 1) << 7
+ exp = (((xbits >> 52) & 0x7FF) - 1023 + config["EXP_BIAS"])
+ mantissa = (xbits >> (53 - config["SIGNIFICAND_BITS"])) & config["MANTISSA_MASK"]
+ absx = xbits & 0x7FFFFFFFFFFFFFFF
+
+ if absx <= config["MINDENORM_O2"]: res = 0
+ elif absx > 0x7FF0000000000000: res = 0x7F if dtype == dtypes.fp8e4m3 else 0x7E | mantissa
+ elif absx > config["OVERFLOW_THRESHOLD"]: res = config["MAXNORM"]
+ elif absx >= config["MINNORM"]:
+ res = ((exp << (config["SIGNIFICAND_BITS"] - 1)) | mantissa)
+ round_bits = xbits & ((FP8_DP_HALF_ULP << 1) - 1)
+ if (round_bits > FP8_DP_HALF_ULP) or (round_bits == FP8_DP_HALF_ULP and (mantissa & 1)): res = res + 1
+ else:
+ shift = 1 - exp
+ mantissa |= 1 << (config["SIGNIFICAND_BITS"] - 1)
+ res = (mantissa >> shift)
+ round_bits = (xbits | (1 << (53 - 1))) & ((FP8_DP_HALF_ULP << (shift + 1)) - 1)
+ if (round_bits > (FP8_DP_HALF_ULP << shift)) or (round_bits == (FP8_DP_HALF_ULP << shift) and (res & 1)):
+ res = res + 1
+
+ res |= sign
+ return int(res)
+
+ def fp8_to_float(x: int, dtype: DType) -> float:
+ assert dtype in dtypes.fp8s, "Only for fp8s"
+ ur = x << 8
+
+ if dtype == dtypes.fp8e5m2 and (ur & 0x7FFF) > 0x7C00: ur = 0x7FFF
+ elif dtype == dtypes.fp8e4m3:
+ sign = ur & 0x8000
+ exponent = ((ur & 0x7800) >> 1) + 0x2000
+ mantissa = (ur & 0x0700) >> 1
+ absx = x & 0x7F
+ if absx == 0x7F: ur = 0x7FFF
+ elif exponent == 0x2000:
+ if mantissa != 0:
+ mantissa <<= 1
+ while (mantissa & 0x0400) == 0:
+ mantissa <<= 1
+ exponent -= 0x0400
+ mantissa &= 0x03FF
+ else:
+ exponent = 0
+ ur = (sign | exponent) | mantissa
+ else:
+ ur = (sign | exponent) | mantissa
+
+ half_bytes = struct.pack('<H', ur)
+ float32_val = struct.unpack('e', half_bytes)[0]
+ return float(float32_val)
+
  truncate: dict[DType, Callable] = {dtypes.bool: bool,
  dtypes.float16: truncate_fp16, dtypes.bfloat16: truncate_bf16,
+ **{fp8: (lambda x, dtype=fp8: fp8_to_float(float_to_fp8(x, dtype), dtype)) for fp8 in dtypes.fp8s},
  dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
  dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
  dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
  dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value, dtypes.int32: lambda x: ctypes.c_int32(x).value,
  dtypes.int64: lambda x: ctypes.c_int64(x).value}
+
+ # numpy and torch dtype interop
+
+ def _to_np_dtype(dtype:DType) -> type|None:
+ import numpy as np
+ return np.dtype(dtype.fmt).type if dtype.fmt is not None else None
+ def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
+ import numpy as np
+ return dtypes.fields()[np.dtype(npdtype).name]
+
+ @functools.cache
+ def _to_torch_dtype(dtype:DType) -> 'torch.dtype'|None: # type: ignore [name-defined] # noqa: F821
+ import numpy as np, torch
+ # NOTE: torch doesn't expose this mapping with a stable API
+ try: return torch.from_numpy(np.array([], dtype=_to_np_dtype(dtype))).dtype
+ except TypeError: return None
+ @functools.cache
+ def _from_torch_dtype(torchdtype:'torch.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
+ return {v:k for k in dtypes.all if (v:=_to_torch_dtype(k)) is not None}[torchdtype]
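Taken together, the dtype.py changes above replace the boolean local pointer flag with an AddrSpace enum, add the two fp8 dtypes to the float hierarchy and the promotion lattice, expose min/max as properties on DType, and add can_safe_cast plus fp8 truncation. A minimal usage sketch, assuming tinygrad 0.11.0 as packaged in this wheel:

from tinygrad.dtype import dtypes, AddrSpace, least_upper_dtype, truncate

# .ptr() now takes an AddrSpace instead of local=True
local_ptr = dtypes.float32.ptr(size=16, addrspace=AddrSpace.LOCAL)

# fp8 dtypes sit below float16/bfloat16 in the promotion lattice
assert least_upper_dtype(dtypes.fp8e4m3, dtypes.float16) == dtypes.float16

# min/max are now properties on DType (they forward to dtypes.min/dtypes.max)
assert (dtypes.int8.min, dtypes.int8.max) == (-128, 127)

# fp8 constants are truncated by a float -> fp8 -> float round trip
print(truncate[dtypes.fp8e5m2](3.14159))  # round-trips through the fp8e5m2 encoding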
tinygrad/engine/jit.py CHANGED
@@ -1,10 +1,10 @@
- from typing import TypeVar, Generic, Callable, Union, cast, Optional, Any
+ from typing import TypeVar, Generic, Callable, cast, Any
  import functools, collections
  from tinygrad.tensor import Tensor
- from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, dedup, partition, unwrap
- from tinygrad.device import Buffer, Compiled, Device
+ from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, JIT_BATCH_SIZE, dedup, partition, unwrap
+ from tinygrad.device import Buffer, Compiled, Device, MultiBuffer
  from tinygrad.dtype import DType
- from tinygrad.ops import UOp, Variable, sym_infer, Ops
+ from tinygrad.uop.ops import UOp, Variable, sym_infer, Ops
  from tinygrad.shape.shapetracker import ShapeTracker
  from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, CompiledRunner, Runner, Estimates
  from tinygrad.engine.memory import _internal_memory_planner
@@ -14,48 +14,52 @@ from weakref import WeakKeyDictionary
 
  class GraphException(Exception): pass
 
+ def graph_class(dev): return dev.graph.func if isinstance(dev.graph, functools.partial) else dev.graph
+
  def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], max_batch_size=0) -> list[ExecItem]:
  # Split JIT cache into batches for faster graph execution.
  # This allows the accelerator to run some batches while subsequent graphs are still being updated.
  graphed_jit_cache: list[ExecItem] = []
  current_batch: list[ExecItem] = []
- current_device: Optional[Compiled] = None
+ current_batch_devs: list[Compiled] = []
 
  def flush_batch():
- nonlocal current_batch, current_device, max_batch_size
+ nonlocal current_batch, current_batch_devs, max_batch_size
  try:
- if current_device is None: raise GraphException("no device for graph")
+ if len(current_batch_devs) == 0: raise GraphException("no device for graph")
  if len(current_batch) <= 1 and not getenv("GRAPH_ONE_KERNEL"): raise GraphException("only one kernel doesn't graph")
- graph_runner = current_device.graph(current_batch, input_rawbuffers, var_vals)
+ graph_runner = current_batch_devs[0].graph(current_batch, input_rawbuffers, var_vals)
  # clear jit inputs to allow their memory to be freed/reused
  for (j,i) in graph_runner.input_replace.keys(): graph_runner.jit_cache[j].bufs[i] = None
- graphed_jit_cache.append(ExecItem(graph_runner, cast(list[Optional[Buffer]], input_rawbuffers)))
+ graphed_jit_cache.append(ExecItem(graph_runner, cast(list[Buffer|None], input_rawbuffers)))
  max_batch_size *= 2
- if DEBUG >= 2: print(f"JIT GRAPHing batch with {len(current_batch)} kernels on device {current_device}")
+ if DEBUG >= 2: print(f"JIT GRAPHing batch with {len(current_batch)} kernels on device {current_batch_devs[0]}")
  except GraphException as e:
  graphed_jit_cache.extend(current_batch)
- if DEBUG >= 2: print(f"JIT GRAPHing failed batch with {len(current_batch)} kernels on device {current_device}: {e}")
+ if DEBUG >= 2: print(f"JIT GRAPHing failed batch with {len(current_batch)} kernels on device {current_batch_devs[0]}: {e}")
  current_batch = []
- current_device = None
+ current_batch_devs = []
 
  for ji in jit_cache:
- if isinstance(ji.prg, ViewOp): continue
- ji_graph_dev: Optional[Compiled] = None # device on which the ji will be graphed. Not graphed if None.
- if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.dev
- elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"CUDA", "NV", "AMD"}:
- ji_graph_dev = Device[ji.bufs[0].device]
-
- graph_class = (ji_graph_dev.graph.func if isinstance(ji_graph_dev.graph, functools.partial) else ji_graph_dev.graph) if ji_graph_dev else None
- can_be_graphed = ji_graph_dev and ji_graph_dev.graph
- can_share_graph = (ji_graph_dev == current_device or (isinstance(graph_class, type) and issubclass(graph_class, MultiGraphRunner)) and
- type(ji_graph_dev) is type(current_device))
- can_extend_graph_batch = can_be_graphed and (max_batch_size == 0 or len(current_batch) < max_batch_size) and can_share_graph
+ match ji.prg:
+ case CompiledRunner(): ji_graph_dev = ji.prg.dev
+ case BufferXfer(): ji_graph_dev = Device[unwrap(ji.bufs[0]).device]
+ case BufferCopy(): ji_graph_dev = next((Device[unwrap(b).device] for b in ji.bufs if unwrap(b).device not in {"CPU", "LLVM"}), None)
+ case ViewOp(): continue # ViewOps are just ignored
+ case _: ji_graph_dev = None # Everything else is not graphed and flushes existing graph if it's being constructed
+
+ # Check if this jit item can be graphed at all, so check if a new graph supports the current item.
+ can_be_graphed = ji_graph_dev is not None and ji_graph_dev.graph is not None and graph_class(ji_graph_dev).supports_exec_item([ji_graph_dev], ji)
+
+ # Check if the current batch can be extended with this item.
+ can_share_graph = can_be_graphed and len(current_batch_devs) > 0 and \
+ graph_class(current_batch_devs[0]).supports_exec_item(dedup(current_batch_devs + [ji_graph_dev]), ji)
+ can_extend_graph_batch = can_share_graph and (max_batch_size == 0 or len(current_batch) < max_batch_size)
+
+ # Flush the current batch if any, since it can't be extended or is full.
  if not can_extend_graph_batch and len(current_batch) > 0: flush_batch()
-
- if can_be_graphed: current_batch.append(ji)
- else: graphed_jit_cache.append(ji)
-
- current_device = ji_graph_dev
+ (current_batch if can_be_graphed else graphed_jit_cache).append(ji)
+ current_batch_devs = dedup(current_batch_devs + [ji_graph_dev]) if can_be_graphed else []
 
  if len(current_batch) > 0: flush_batch()
  return graphed_jit_cache
@@ -72,8 +76,8 @@ class GraphRunner(Runner):
  def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]):
  self.jit_cache = jit_cache # NOTE: this is not used, but you have to keep these objects alive for the Graph
  self.input_replace:dict[tuple[int, int], int] = get_input_replace(jit_cache, input_rawbuffers)
- self.var_vals_replace:dict[int, list[int]] = {}
- self.launch_dims_replace:dict[int, tuple[Optional[int], Optional[int]]] = {}
+ self.var_vals_replace:dict[int, list[tuple[int, int]]] = {}
+ self.launch_dims_replace:dict[int, tuple[int|None, int|None]] = {}
  self.launch_dims_base:dict[int, tuple[tuple[int, ...], tuple[int, ...]]] = {}
 
  def is_sym_dim(dim) -> bool: return not all(isinstance(d, (int, float)) for d in dim)
@@ -87,7 +91,7 @@ class GraphRunner(Runner):
  for j,ji in enumerate(jit_cache):
  estimates += ji.prg.estimates
  if isinstance(ji.prg, CompiledRunner):
- if ji.prg.p.vars: self.var_vals_replace[j] = [self.vars.index(v) for v in ji.prg.p.vars]
+ if ji.prg.p.vars: self.var_vals_replace[j] = [(i, self.vars.index(v)) for i, v in enumerate(ji.prg.p.vars) if v not in ji.fixedvars]
 
  global_dim_idx, local_dim_idx = find_symbolic_dim(ji.prg.p.global_size), find_symbolic_dim(ji.prg.p.local_size)
  if global_dim_idx is not None or local_dim_idx is not None:
@@ -104,7 +108,7 @@ class GraphRunner(Runner):
  def updated_vars(self, var_vals: dict[Variable, int]):
  vals = [var_vals[v] for v in self.vars]
  for j, vidxs in self.var_vals_replace.items():
- for i, v in enumerate(vidxs): yield j, i, vals[v]
+ for i, v in vidxs: yield j, i, vals[v]
 
  def updated_launch_dims(self, var_vals: dict[Variable, int]):
  dims = [tuple(sym_infer(s, var_vals) for s in dim) for dim in self.symbolic_dims]
@@ -120,21 +124,31 @@ class GraphRunner(Runner):
  if id(rawbuf.base._buf) in self.w_dependency_map: wait_nodes.append(self.w_dependency_map[id(rawbuf.base._buf)])
  if i in write:
  if id(rawbuf.base._buf) in self.r_dependency_map: wait_nodes.extend(self.r_dependency_map.pop(id(rawbuf.base._buf)))
- self.w_dependency_map[id(rawbuf.base._buf)] = new_dependency
+
+ for i,rawbuf in enumerate(rawbufs):
+ if i in write: self.w_dependency_map[id(rawbuf.base._buf)] = new_dependency
  else: self.r_dependency_map[id(rawbuf.base._buf)].append(new_dependency)
 
  return list({id(x):x for x in wait_nodes}.values())
 
+ @staticmethod
+ def supports_exec_item(devs:list[Compiled], ei:ExecItem) -> bool: return isinstance(ei.prg, CompiledRunner) and len(dedup(devs)) == 1
+
  # a marker for your graph supporting multiple devices of the same type
- class MultiGraphRunner(GraphRunner): pass
+ class MultiGraphRunner(GraphRunner):
+ @staticmethod
+ def supports_exec_item(devs:list[Compiled], ei:ExecItem) -> bool:
+ # Devices must be the same type
+ return isinstance(ei.prg, (CompiledRunner, BufferXfer)) and len(dedup([type(Device[b.device]) for b in ei.bufs if b]+[type(d) for d in devs]))==1
+
+ def get_out_buffers_for_ei(ei:ExecItem) -> list[Buffer]:
+ if isinstance(ei.prg, CompiledRunner): return [cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs if out not in ei.prg.p.ins]
+ if isinstance(ei.prg, (BufferCopy, BufferXfer)): return [cast(Buffer, ei.bufs[0])]
+ return []
 
  def update_depends(depends:set[Buffer|None], jit_cache:list[ExecItem]):
  for ei in jit_cache:
- if any(b in depends for b in ei.bufs):
- if isinstance(ei.prg, CompiledRunner):
- depends.update(cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs if out not in ei.prg.p.ins)
- if isinstance(ei.prg, (BufferCopy, BufferXfer)):
- depends.add(cast(Buffer, ei.bufs[0]))
+ if any(b in depends for b in ei.bufs): depends.update(get_out_buffers_for_ei(ei))
 
  ReturnType = TypeVar('ReturnType')
  @dataclass
@@ -143,11 +157,11 @@ class CapturedJit(Generic[ReturnType]):
  jit_cache: list[ExecItem]
  input_replace: dict[tuple[int, int], int]
  extra_view_inputs: list[tuple[int, int, str, int, DType]]
- expected_names: list[Union[int, str]]
+ expected_names: list[int|str]
  expected_st_vars_dtype_device: list[tuple[ShapeTracker, tuple[Variable, ...], DType, str]]
 
  def __reduce__(self):
- # TODO: free_intermediates here?
+ # TODO: free_intermediates here? replan_buffers_memory_layout here?
  return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs,
  self.expected_names, self.expected_st_vars_dtype_device)
 
@@ -164,9 +178,19 @@
  depends: set[Buffer|None] = set([None])
  update_depends(depends, self.jit_cache)
  for b in depends:
- if b is not None: b.deallocate()
+ if b is not None:
+ if b.is_allocated(): b.deallocate()
+ if b._base is not None and b._base.allocated_views == 0 and b._base.is_allocated(): b._base.deallocate()
  self.__post_init__() # reset the graph state
 
+ def replan_buffers_memory_layout(self):
+ blacklist = [t.uop.buffer for t in get_parameters(self.ret)]
+ asgn = _internal_memory_planner([[b for item in self.jit_cache for b in item.bufs if b is not None and b not in blacklist]], ignore_checks=True)
+ self.jit_cache = [ExecItem(item.prg, [asgn.get(b,b) if b is not None else None for b in item.bufs]) for item in self.jit_cache]
+ for old, new in asgn.items():
+ if old.is_allocated(): new.ensure_allocated().copyin(old.as_buffer())
+ self.__post_init__()
+
  # jit exec
  def __call__(self, input_buffers:list[Buffer], var_vals:dict[Variable, int]) -> ReturnType:
  # assign inputs
@@ -182,7 +206,7 @@
  if b is not None: b.ensure_allocated()
  # create graph if needed
  if JIT < 2:
- self._jit_cache = apply_graph_to_jit(self.jit_cache, input_buffers, var_vals, max_batch_size=getenv("JIT_BATCH_SIZE", 32))
+ self._jit_cache = apply_graph_to_jit(self.jit_cache, input_buffers, var_vals, max_batch_size=JIT_BATCH_SIZE.value)
  self._input_replace = get_input_replace(self._jit_cache, input_buffers)
  self._first_run = False
 
@@ -194,10 +218,11 @@
  def _prepare_jit_inputs(args, kwargs):
  input_tensors: list[tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor]
  names, tensors = [name for name,_ in input_tensors], [t for _,t in input_tensors]
- if len(unrealized_tensors := [x for x in tensors if not x.lazydata.is_realized]): Tensor.realize(*unrealized_tensors)
- # TODO: should we be unpacking multi here?
- lbs: list[UOp] = flatten([t.lazydata.src if t.lazydata.op is Ops.MULTI else [t.lazydata] for t in tensors])
- input_buffers: list[Buffer] = [lb.base.realized for lb in lbs if lb.base.realized is not None]
+ if len(unrealized_tensors := [x for x in tensors if not x.uop.is_realized]): Tensor.realize(*unrealized_tensors)
+ # TODO: this multi unpack stuff is not well tested.
+ lbs: list[UOp] = flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors])
+ input_buffers: list[Buffer] = flatten([rb.bufs if isinstance(rb:=lb.base.realized, MultiBuffer) else [rb]
+ for lb in lbs if lb.base.realized is not None])
  assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
  st_varval_dtype_device = [(*unwrap(lb.st).unbind(), lb.dtype, lb.device) for lb in lbs]
  var_vals = merge_dicts([x[1] for x in st_varval_dtype_device] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))])
@@ -205,16 +230,17 @@
  return input_buffers, var_vals, names, st_vars_dtype_device
 
  class TinyJit(Generic[ReturnType]):
- def __init__(self, fxn:Optional[Callable[..., ReturnType]], captured:Optional[CapturedJit]=None, prune=False):
+ def __init__(self, fxn:Callable[..., ReturnType]|None, captured:CapturedJit|None=None, prune=False, optimize=False):
  assert fxn or captured, "need either a function or a CapturedJit"
  self.fxn = fxn
- self.captured: Optional[CapturedJit] = captured
+ self.captured: CapturedJit|None = captured
  self.cnt: int = 2 if self.fxn is None else 0
  self.prune = prune
+ self.optimize = optimize
 
  def add_buffer(self, b:Buffer) -> Buffer:
  if found:=self._buffer_replace.get(b, None): return found
- if b.is_allocated() or b.lb_refcount > 0: return b
+ if b.is_allocated() or b.uop_refcount > 0: return b
  if b._base is not None:
  self._buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, base=self.add_buffer(b._base), offset=b.offset)
  else:
@@ -222,7 +248,7 @@
  return ret
 
  def add(self, ei:ExecItem):
- self._jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.bufs if buf is not None]))
+ self._jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.bufs if buf is not None], ei.metadata, ei.fixedvars))
 
  def reset(self):
  assert self.fxn is not None, "can't reset without function"
@@ -281,8 +307,7 @@
  if self.prune:
  depends = set(input_buffers)
  update_depends(depends, jit_cache)
- pruned, onetime = partition(jit_cache,
- lambda ei: not isinstance(ei.prg, CompiledRunner) or any(ei.bufs[out] in depends for out in ei.prg.p.outs))
+ pruned, onetime = partition(jit_cache, lambda ei: any(b in depends for b in get_out_buffers_for_ei(ei)))
  if DEBUG >= 1: print(f"pruned from {len(jit_cache)} -> {len(pruned)} kernels")
  # run the onetime kernels here
  for ei in onetime:
@@ -294,13 +319,15 @@
  # Exclude buffers involved in transfer ops to preserve parallelism.
  noopt_buffers = {b for ji in jit_cache if isinstance(ji.prg, BufferXfer) for b in ji.bufs}
  assigned = _internal_memory_planner([cast(list[Buffer], item.bufs) for item in jit_cache], noopt_buffers, debug_prefix="JIT ")
- jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None]) for item in jit_cache]
+ jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None],
+ item.metadata, item.fixedvars) for item in jit_cache]
 
  input_replace = get_input_replace(jit_cache, input_buffers)
  if DEBUG >= 1 and len(set(input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found")
 
  # set this for next run
  self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, st_vars_dtype_device)
+ if self.optimize: self.captured.replan_buffers_memory_layout()
  elif self.cnt >= 2:
  # jit exec
  assert self.captured is not None
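For context on the engine/jit.py changes above: TinyJit gains an optimize flag (which calls CapturedJit.replan_buffers_memory_layout after capture), graph batching is now driven by per-graph supports_exec_item checks, and the batch size comes from the JIT_BATCH_SIZE helper (used as JIT_BATCH_SIZE.value) instead of a raw getenv. A minimal sketch of the user-facing surface, assuming tinygrad 0.11.0; the jitted function here is a made-up example:

from tinygrad import Tensor, TinyJit

# prune and the new optimize flag are passed at construction time
step = TinyJit(lambda x: (x * 2).sum().realize(), prune=True, optimize=True)

for i in range(4):
  # tinygrad runs the function normally on the first call, captures on the
  # second, and replays the captured (possibly graphed) kernels afterwards
  out = step(Tensor.rand(16))
  print(i, out.item())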