tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +1 -1
- tinygrad/apps/llm.py +206 -0
- tinygrad/codegen/__init__.py +116 -0
- tinygrad/codegen/devectorizer.py +315 -172
- tinygrad/codegen/expander.py +8 -16
- tinygrad/codegen/gpudims.py +89 -0
- tinygrad/codegen/linearize.py +205 -203
- tinygrad/codegen/lowerer.py +92 -139
- tinygrad/codegen/opt/__init__.py +38 -0
- tinygrad/codegen/opt/heuristic.py +125 -0
- tinygrad/codegen/opt/kernel.py +510 -0
- tinygrad/{engine → codegen/opt}/search.py +51 -35
- tinygrad/codegen/opt/swizzler.py +134 -0
- tinygrad/codegen/opt/tc.py +127 -0
- tinygrad/codegen/quantize.py +67 -0
- tinygrad/device.py +122 -132
- tinygrad/dtype.py +152 -35
- tinygrad/engine/jit.py +81 -54
- tinygrad/engine/memory.py +46 -27
- tinygrad/engine/realize.py +82 -41
- tinygrad/engine/schedule.py +70 -445
- tinygrad/frontend/__init__.py +0 -0
- tinygrad/frontend/onnx.py +1253 -0
- tinygrad/frontend/torch.py +5 -0
- tinygrad/gradient.py +19 -27
- tinygrad/helpers.py +95 -47
- tinygrad/nn/__init__.py +7 -8
- tinygrad/nn/optim.py +72 -41
- tinygrad/nn/state.py +37 -23
- tinygrad/renderer/__init__.py +40 -60
- tinygrad/renderer/cstyle.py +143 -128
- tinygrad/renderer/llvmir.py +113 -62
- tinygrad/renderer/ptx.py +50 -32
- tinygrad/renderer/wgsl.py +27 -23
- tinygrad/runtime/autogen/am/am.py +5861 -0
- tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
- tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
- tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
- tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
- tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
- tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
- tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
- tinygrad/runtime/autogen/comgr.py +35 -9
- tinygrad/runtime/autogen/comgr_3.py +906 -0
- tinygrad/runtime/autogen/cuda.py +2419 -494
- tinygrad/runtime/autogen/hsa.py +57 -16
- tinygrad/runtime/autogen/ib.py +7171 -0
- tinygrad/runtime/autogen/io_uring.py +917 -118
- tinygrad/runtime/autogen/kfd.py +748 -26
- tinygrad/runtime/autogen/libc.py +613 -218
- tinygrad/runtime/autogen/libusb.py +1643 -0
- tinygrad/runtime/autogen/nv/nv.py +8602 -0
- tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
- tinygrad/runtime/autogen/opencl.py +2 -4
- tinygrad/runtime/autogen/sqtt.py +1789 -0
- tinygrad/runtime/autogen/vfio.py +3 -3
- tinygrad/runtime/autogen/webgpu.py +273 -264
- tinygrad/runtime/graph/cuda.py +3 -3
- tinygrad/runtime/graph/hcq.py +68 -29
- tinygrad/runtime/graph/metal.py +29 -13
- tinygrad/runtime/graph/remote.py +114 -0
- tinygrad/runtime/ops_amd.py +537 -320
- tinygrad/runtime/ops_cpu.py +108 -7
- tinygrad/runtime/ops_cuda.py +12 -14
- tinygrad/runtime/ops_disk.py +13 -10
- tinygrad/runtime/ops_dsp.py +47 -40
- tinygrad/runtime/ops_gpu.py +13 -11
- tinygrad/runtime/ops_hip.py +6 -9
- tinygrad/runtime/ops_llvm.py +35 -15
- tinygrad/runtime/ops_metal.py +29 -19
- tinygrad/runtime/ops_npy.py +5 -3
- tinygrad/runtime/ops_null.py +28 -0
- tinygrad/runtime/ops_nv.py +306 -234
- tinygrad/runtime/ops_python.py +62 -52
- tinygrad/runtime/ops_qcom.py +28 -39
- tinygrad/runtime/ops_remote.py +482 -0
- tinygrad/runtime/ops_webgpu.py +28 -28
- tinygrad/runtime/support/am/amdev.py +114 -249
- tinygrad/runtime/support/am/ip.py +211 -172
- tinygrad/runtime/support/amd.py +138 -0
- tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
- tinygrad/runtime/support/compiler_cuda.py +8 -11
- tinygrad/runtime/support/elf.py +2 -1
- tinygrad/runtime/support/hcq.py +184 -97
- tinygrad/runtime/support/ib.py +172 -0
- tinygrad/runtime/support/llvm.py +3 -4
- tinygrad/runtime/support/memory.py +251 -0
- tinygrad/runtime/support/nv/__init__.py +0 -0
- tinygrad/runtime/support/nv/ip.py +581 -0
- tinygrad/runtime/support/nv/nvdev.py +183 -0
- tinygrad/runtime/support/system.py +170 -0
- tinygrad/runtime/support/usb.py +268 -0
- tinygrad/runtime/support/webgpu.py +18 -0
- tinygrad/schedule/__init__.py +0 -0
- tinygrad/schedule/grouper.py +119 -0
- tinygrad/schedule/kernelize.py +368 -0
- tinygrad/schedule/multi.py +231 -0
- tinygrad/shape/shapetracker.py +40 -46
- tinygrad/shape/view.py +88 -52
- tinygrad/tensor.py +968 -542
- tinygrad/uop/__init__.py +117 -0
- tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
- tinygrad/uop/mathtraits.py +169 -0
- tinygrad/uop/ops.py +1021 -0
- tinygrad/uop/spec.py +228 -0
- tinygrad/{codegen → uop}/symbolic.py +239 -216
- tinygrad/uop/upat.py +163 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
- tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
- tinygrad/viz/index.html +203 -403
- tinygrad/viz/js/index.js +718 -0
- tinygrad/viz/js/worker.js +29 -0
- tinygrad/viz/serve.py +224 -102
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
- tinygrad-0.11.0.dist-info/RECORD +141 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/kernel.py +0 -693
- tinygrad/engine/multi.py +0 -161
- tinygrad/ops.py +0 -1003
- tinygrad/runtime/ops_cloud.py +0 -220
- tinygrad/runtime/support/allocator.py +0 -94
- tinygrad/spec.py +0 -155
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
- tinygrad/viz/perfetto.html +0 -178
- tinygrad-0.10.2.dist-info/RECORD +0 -99
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
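The file list above reflects a large internal reorganization: the UOp machinery moved from tinygrad/ops.py into the new tinygrad/uop/ package, kernel optimization moved under tinygrad/codegen/opt/, scheduling into tinygrad/schedule/, and an ONNX frontend now ships in the wheel under tinygrad/frontend/. For downstream code, the only import change confirmed by the diffs shown below is the UOp module move; a minimal sketch:

# tinygrad 0.10.2
# from tinygrad.ops import UOp, Variable, sym_infer, Ops

# tinygrad 0.11.0 (as used by tinygrad/engine/jit.py below)
from tinygrad.uop.ops import UOp, Variable, sym_infer, Ops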
tinygrad/dtype.py
CHANGED
@@ -1,10 +1,11 @@
 from __future__ import annotations
-from typing import Final, …
+from typing import Final, ClassVar, Callable, Literal
 import math, struct, ctypes, functools
 from dataclasses import dataclass, fields
 from tinygrad.helpers import getenv, prod
+from enum import Enum, auto
 
-ConstType = …
+ConstType = float|int|bool
 
 FmtStr = Literal['?', 'b', 'B', 'h', 'H', 'i', 'I', 'q', 'Q', 'e', 'f', 'd']
 
@@ -16,16 +17,18 @@ class DTypeMetaClass(type):
     DTypeMetaClass.dcache[args] = ret = super().__call__(*args)
     return ret
 
+class AddrSpace(Enum): GLOBAL = auto(); LOCAL = auto(); REG = auto() # noqa: E702
+
 @dataclass(frozen=True, eq=False)
 class DType(metaclass=DTypeMetaClass):
   priority: int # this determines when things get upcasted
   itemsize: int
   name: str
-  fmt: …
+  fmt: FmtStr|None
   count: int
-  _scalar: …
+  _scalar: DType|None
   @staticmethod
-  def new(priority:int, itemsize:int, name:str, fmt:…
+  def new(priority:int, itemsize:int, name:str, fmt:FmtStr|None): return DType(priority, itemsize, name, fmt, 1, None)
   def __reduce__(self): return type(self), tuple(getattr(self, f.name) for f in fields(self))
   def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self.scalar().name]}"+(f".vec({self.count})" if self.count > 1 else "")
   def __lt__(self, o:DType): return (self.priority, self.itemsize, self.name, self.fmt, self.count) < (o.priority, o.itemsize, o.name, o.fmt, o.count)
@@ -33,51 +36,62 @@ class DType(metaclass=DTypeMetaClass):
   def base(self): return self
   @property
   def vcount(self): return self.count
-  @functools.…
+  @functools.cache # pylint: disable=method-cache-max-size-none
   def vec(self, sz:int) -> DType:
     assert self.count == 1, f"can't vectorize {self} with size {sz}"
     if sz == 1 or self == dtypes.void: return self # void doesn't vectorize, and sz=1 is scalar
     return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz, self)
-  def ptr(self, size=-1, …
-    return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, None, self, …
+  def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL) -> PtrDType:
+    return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, None, self, addrspace, 1, size)
   def scalar(self) -> DType: return self._scalar if self._scalar is not None else self
+  def nbytes(self): raise RuntimeError("only ptr types have nbytes")
+  @property
+  def min(self): return dtypes.min(self)
+  @property
+  def max(self): return dtypes.max(self)
 
 @dataclass(frozen=True, eq=False)
 class PtrDType(DType):
   _base: DType
-  …
+  addrspace: AddrSpace
   v: int
   size: int = -1 # -1 is unlimited size
   @property
   def base(self): return self._base
-  @functools.…
+  @functools.cache # pylint: disable=method-cache-max-size-none
   def vec(self, sz:int) -> DType:
     assert self.v == 1, f"can't vectorize ptr {self} with size {sz}"
     if sz == 1: return self # sz=1 is a scalar
-    …
-    …
+    if isinstance(self, ImageDType):
+      return ImageDType(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.addrspace, sz, self.size, self.shape)
+    return type(self)(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.addrspace, sz, self.size)
+  def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL): raise RuntimeError("can't make a pointer from a pointer")
+  def nbytes(self) -> int:
+    if self.size == -1: return 0 # TODO: this should be an exception
+    return self.size*self.itemsize
   @property
   def vcount(self): return self.v
   def __repr__(self):
-    return f"{self.base.__repr__()}.ptr({self.size}{', …
+    return f"{self.base.__repr__()}.ptr({self.size}{', '+str(self.addrspace) if self.addrspace != AddrSpace.GLOBAL else ''})" + \
+           (f'.vec({self.v})' if self.v != 1 else '')
 
 @dataclass(frozen=True, eq=False)
 class ImageDType(PtrDType):
   shape: tuple[int, ...] = () # shape of the Image
-  def ptr(self, size=-1, …
-    assert …
+  def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL) -> PtrDType:
+    assert addrspace == AddrSpace.GLOBAL, "images can't be local"
     return self
   def __repr__(self): return f"dtypes.{self.name}({self.shape})" + (f'.vec({self.v})' if self.v != 1 else '')
 
 class dtypes:
   @staticmethod
-  @functools.…
+  @functools.cache
   def is_float(x: DType) -> bool: return x.scalar() in dtypes.floats or isinstance(x, ImageDType)
   @staticmethod # static methods on top, or bool in the type info will refer to dtypes.bool
-  @functools.…
+  @functools.cache
   def is_int(x: DType) -> bool: return x.scalar() in dtypes.ints
   @staticmethod
-  @functools.…
+  @functools.cache
   def is_unsigned(x: DType) -> bool: return x.scalar() in dtypes.uints
   @staticmethod
   def is_bool(x: DType) -> bool: return x.scalar() == dtypes.bool
@@ -97,12 +111,12 @@ class dtypes:
     # TODO: should truncate here
     return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
   @staticmethod
-  @functools.…
+  @functools.cache
   def min(dtype:DType):
     if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.itemsize*8-1)
     return -float("inf") if dtypes.is_float(dtype) else False
   @staticmethod
-  @functools.…
+  @functools.cache
   def max(dtype:DType):
     if dtypes.is_int(dtype): return 2**(dtype.itemsize*8)-1+dtypes.min(dtype)
     return float("inf") if dtypes.is_float(dtype) else True
@@ -110,7 +124,8 @@ class dtypes:
   def finfo(dtype:DType) -> tuple[int, int]:
     """(exponent, mantissa)"""
     if not dtypes.is_float(dtype): raise ValueError(f"{dtype} is not a floating point type")
-    return {dtypes.float16: (5, 10), dtypes.bfloat16: (8, 7), dtypes.float32: (8, 23), dtypes.float64: (11, 52)…
+    return {dtypes.float16: (5, 10), dtypes.bfloat16: (8, 7), dtypes.float32: (8, 23), dtypes.float64: (11, 52),
+            dtypes.fp8e5m2: (5, 2), dtypes.fp8e4m3: (4, 3)}[dtype]
   @staticmethod
   def fields() -> dict[str, DType]: return DTYPES_DICT
   void: Final[DType] = DType.new(-1, 0, "void", None)
@@ -123,11 +138,13 @@ class dtypes:
   uint32: Final[DType] = DType.new(6, 4, "unsigned int", 'I')
   int64: Final[DType] = DType.new(7, 8, "long", 'q')
   uint64: Final[DType] = DType.new(8, 8, "unsigned long", 'Q')
-  …
+  fp8e4m3: Final[DType] = DType.new(9, 1, "float8_e4m3", None)
+  fp8e5m2: Final[DType] = DType.new(10, 1, "float8_e5m2", None)
+  float16: Final[DType] = DType.new(11, 2, "half", 'e')
   # bfloat16 has higher priority than float16, so least_upper_dtype(dtypes.int64, dtypes.uint64) = dtypes.float16
-  bfloat16: Final[DType] = DType.new(…
-  float32: Final[DType] = DType.new(…
-  float64: Final[DType] = DType.new(…
+  bfloat16: Final[DType] = DType.new(12, 2, "__bf16", None)
+  float32: Final[DType] = DType.new(13, 4, "float", 'f')
+  float64: Final[DType] = DType.new(14, 8, "double", 'd')
 
   # dtype aliases
   half = float16; float = float32; double = float64 # noqa: E702
@@ -136,48 +153,66 @@ class dtypes:
 
   # NOTE: these are image dtypes
   @staticmethod
-  def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, None, dtypes.float32, …
+  def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp)
   @staticmethod
-  def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, None, dtypes.float32, …
+  def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp)
 
   default_float: ClassVar[DType] = float32
   default_int: ClassVar[DType] = int32
 
-  …
+  fp8s = (fp8e4m3, fp8e5m2)
+  floats = fp8s + (float16, bfloat16, float32, float64)
   uints = (uint8, uint16, uint32, uint64)
   sints = (int8, int16, int32, int64)
   ints = uints + sints
+  all = floats + ints + (bool,)
 
 if (env_default_float := getenv("DEFAULT_FLOAT", "")):
   dtypes.default_float = getattr(dtypes, env_default_float.lower())
   assert dtypes.is_float(dtypes.default_float), f"{env_default_float} is not a float dtype"
 
-DTypeLike = …
-def to_dtype(dtype:DTypeLike) -> DType: return dtype if isinstance(dtype, DType) else getattr(dtypes, dtype)
+DTypeLike = str|DType
+def to_dtype(dtype:DTypeLike) -> DType: return dtype if isinstance(dtype, DType) else getattr(dtypes, dtype.lower())
 
 # https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
 # we don't support weak type and complex type
 promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64],
   dtypes.int64: [dtypes.float16, dtypes.bfloat16], dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32],
   dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.float16, dtypes.bfloat16],
+  dtypes.fp8e5m2: [dtypes.float16, dtypes.bfloat16], dtypes.fp8e4m3: [dtypes.float16, dtypes.bfloat16],
   dtypes.float16: [dtypes.float32], dtypes.bfloat16: [dtypes.float32], dtypes.float32: [dtypes.float64], }
 
-@functools.…
+@functools.cache
 def _get_recursive_parents(dtype:DType) -> set[DType]:
   return set.union(*[_get_recursive_parents(d) for d in promo_lattice[dtype]], {dtype}) if dtype != dtypes.float64 else {dtypes.float64}
-@functools.…
+@functools.cache
 def least_upper_dtype(*ds:DType) -> DType:
   return min(set.intersection(*[_get_recursive_parents(d) for d in ds])) if not (images:=[d for d in ds if isinstance(d, ImageDType)]) else images[0]
-def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.…
+def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.default_float)
 
 DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if isinstance(v, DType) and not k.startswith(("default", "void"))}
 INVERSE_DTYPES_DICT = {**{v.name:k for k,v in DTYPES_DICT.items()}, "void": "void"}
 
+@functools.cache
+def can_safe_cast(dt0:DType, dt1:DType) -> bool:
+  # return if dt1 preserves value of dt0
+  # https://numpy.org/doc/stable/reference/generated/numpy.can_cast.html
+  if dt0 == dt1 or dt0 == dtypes.bool: return True
+  match dt1:
+    case dtypes.double: return dt0 in (dtypes.float, dtypes.half, dtypes.bfloat16)
+    case dtypes.float: return dt0 in (dtypes.half, dtypes.bfloat16)
+    case dtypes.uint64: return dt0 in (dtypes.uint32, dtypes.uint16, dtypes.uint8)
+    case dtypes.uint32: return dt0 in (dtypes.uint16, dtypes.uint8)
+    case dtypes.int64: return dt0 in (dtypes.uint32, dtypes.uint16, dtypes.uint8, dtypes.int32, dtypes.int16, dtypes.int8)
+    case dtypes.int32: return dt0 in (dtypes.uint16, dtypes.uint8, dtypes.int16, dtypes.int8)
+    case dtypes.int16: return dt0 in (dtypes.uint8, dtypes.int8)
+    case _: return False
+
 def sum_acc_dtype(dt:DType):
   # default acc dtype for sum
   if dtypes.is_unsigned(dt): return least_upper_dtype(dt, dtypes.uint)
   if dtypes.is_int(dt) or dt == dtypes.bool: return least_upper_dtype(dt, dtypes.int)
-  return least_upper_dtype(dt, …
+  return least_upper_dtype(dt, to_dtype(getenv("SUM_DTYPE", "float32")))
 
 def truncate_fp16(x):
   try: return struct.unpack("@e", struct.pack("@e", float(x)))[0]
@@ -185,15 +220,97 @@ def truncate_fp16(x):
 
 def truncate_bf16(x):
   max_bf16 = struct.unpack('f', struct.pack('I', 0x7f7f0000))[0]
-  if x > max_bf16…
+  if abs(x) > max_bf16: return math.copysign(math.inf, x)
   f32_int = struct.unpack('I', struct.pack('f', x))[0]
   bf = struct.unpack('f', struct.pack('I', f32_int & 0xFFFF0000))[0]
   return bf
 
+# fp8-float conversions based on https://gitlab.com/nvidia/headers/cuda-individual/cudart/-/blob/main/cuda_fp8.hpp
+def float_to_fp8(x: float, dtype: DType) -> int:
+  assert dtype in dtypes.fp8s, "Only for fp8s"
+  config = {
+    dtypes.fp8e4m3: {"EXP_BIAS": 7, "SIGNIFICAND_BITS": 4, "MANTISSA_MASK": 0x7, "MINDENORM_O2": 0x3F50000000000000,
+      "OVERFLOW_THRESHOLD": 0x407D000000000000, "MAXNORM": 0x7E, "MINNORM": 0x3F90000000000000, "INF_VALUE": 0x7F},
+    dtypes.fp8e5m2: {"EXP_BIAS": 15, "SIGNIFICAND_BITS": 3, "MANTISSA_MASK": 0x3, "MINDENORM_O2": 0x3EE0000000000000,
+      "OVERFLOW_THRESHOLD": 0x40EE000000000000 - 1, "MAXNORM": 0x7B, "MINNORM": 0x3F10000000000000, "INF_VALUE": 0x7E}
+  }[dtype]
+  xbits, = struct.unpack('Q', struct.pack('d', x))
+  FP8_DP_HALF_ULP = 1 << (53 - config["SIGNIFICAND_BITS"] - 1)
+  sign = ((xbits >> 63) & 1) << 7
+  exp = (((xbits >> 52) & 0x7FF) - 1023 + config["EXP_BIAS"])
+  mantissa = (xbits >> (53 - config["SIGNIFICAND_BITS"])) & config["MANTISSA_MASK"]
+  absx = xbits & 0x7FFFFFFFFFFFFFFF
+
+  if absx <= config["MINDENORM_O2"]: res = 0
+  elif absx > 0x7FF0000000000000: res = 0x7F if dtype == dtypes.fp8e4m3 else 0x7E | mantissa
+  elif absx > config["OVERFLOW_THRESHOLD"]: res = config["MAXNORM"]
+  elif absx >= config["MINNORM"]:
+    res = ((exp << (config["SIGNIFICAND_BITS"] - 1)) | mantissa)
+    round_bits = xbits & ((FP8_DP_HALF_ULP << 1) - 1)
+    if (round_bits > FP8_DP_HALF_ULP) or (round_bits == FP8_DP_HALF_ULP and (mantissa & 1)): res = res + 1
+  else:
+    shift = 1 - exp
+    mantissa |= 1 << (config["SIGNIFICAND_BITS"] - 1)
+    res = (mantissa >> shift)
+    round_bits = (xbits | (1 << (53 - 1))) & ((FP8_DP_HALF_ULP << (shift + 1)) - 1)
+    if (round_bits > (FP8_DP_HALF_ULP << shift)) or (round_bits == (FP8_DP_HALF_ULP << shift) and (res & 1)):
+      res = res + 1
+
+  res |= sign
+  return int(res)
+
+def fp8_to_float(x: int, dtype: DType) -> float:
+  assert dtype in dtypes.fp8s, "Only for fp8s"
+  ur = x << 8
+
+  if dtype == dtypes.fp8e5m2 and (ur & 0x7FFF) > 0x7C00: ur = 0x7FFF
+  elif dtype == dtypes.fp8e4m3:
+    sign = ur & 0x8000
+    exponent = ((ur & 0x7800) >> 1) + 0x2000
+    mantissa = (ur & 0x0700) >> 1
+    absx = x & 0x7F
+    if absx == 0x7F: ur = 0x7FFF
+    elif exponent == 0x2000:
+      if mantissa != 0:
+        mantissa <<= 1
+        while (mantissa & 0x0400) == 0:
+          mantissa <<= 1
+          exponent -= 0x0400
+        mantissa &= 0x03FF
+      else:
+        exponent = 0
+      ur = (sign | exponent) | mantissa
+    else:
+      ur = (sign | exponent) | mantissa
+
+  half_bytes = struct.pack('<H', ur)
+  float32_val = struct.unpack('e', half_bytes)[0]
+  return float(float32_val)
+
 truncate: dict[DType, Callable] = {dtypes.bool: bool,
   dtypes.float16: truncate_fp16, dtypes.bfloat16: truncate_bf16,
+  **{fp8: (lambda x, dtype=fp8: fp8_to_float(float_to_fp8(x, dtype), dtype)) for fp8 in dtypes.fp8s},
   dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
   dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
   dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
   dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value, dtypes.int32: lambda x: ctypes.c_int32(x).value,
   dtypes.int64: lambda x: ctypes.c_int64(x).value}
+
+# numpy and torch dtype interop
+
+def _to_np_dtype(dtype:DType) -> type|None:
+  import numpy as np
+  return np.dtype(dtype.fmt).type if dtype.fmt is not None else None
+def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
+  import numpy as np
+  return dtypes.fields()[np.dtype(npdtype).name]
+
+@functools.cache
+def _to_torch_dtype(dtype:DType) -> 'torch.dtype'|None: # type: ignore [name-defined] # noqa: F821
+  import numpy as np, torch
+  # NOTE: torch doesn't expose this mapping with a stable API
+  try: return torch.from_numpy(np.array([], dtype=_to_np_dtype(dtype))).dtype
+  except TypeError: return None
+@functools.cache
+def _from_torch_dtype(torchdtype:'torch.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
+  return {v:k for k in dtypes.all if (v:=_to_torch_dtype(k)) is not None}[torchdtype]
tinygrad/engine/jit.py
CHANGED
@@ -1,10 +1,10 @@
-from typing import TypeVar, Generic, Callable, …
+from typing import TypeVar, Generic, Callable, cast, Any
 import functools, collections
 from tinygrad.tensor import Tensor
-from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, dedup, partition, unwrap
-from tinygrad.device import Buffer, Compiled, Device
+from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, JIT_BATCH_SIZE, dedup, partition, unwrap
+from tinygrad.device import Buffer, Compiled, Device, MultiBuffer
 from tinygrad.dtype import DType
-from tinygrad.ops import UOp, Variable, sym_infer, Ops
+from tinygrad.uop.ops import UOp, Variable, sym_infer, Ops
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, CompiledRunner, Runner, Estimates
 from tinygrad.engine.memory import _internal_memory_planner
@@ -14,48 +14,52 @@ from weakref import WeakKeyDictionary
 
 class GraphException(Exception): pass
 
+def graph_class(dev): return dev.graph.func if isinstance(dev.graph, functools.partial) else dev.graph
+
 def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], max_batch_size=0) -> list[ExecItem]:
   # Split JIT cache into batches for faster graph execution.
   # This allows the accelerator to run some batches while subsequent graphs are still being updated.
   graphed_jit_cache: list[ExecItem] = []
   current_batch: list[ExecItem] = []
-  …
+  current_batch_devs: list[Compiled] = []
 
   def flush_batch():
-    nonlocal current_batch, …
+    nonlocal current_batch, current_batch_devs, max_batch_size
     try:
-      if …
+      if len(current_batch_devs) == 0: raise GraphException("no device for graph")
      if len(current_batch) <= 1 and not getenv("GRAPH_ONE_KERNEL"): raise GraphException("only one kernel doesn't graph")
-      graph_runner = …
+      graph_runner = current_batch_devs[0].graph(current_batch, input_rawbuffers, var_vals)
       # clear jit inputs to allow their memory to be freed/reused
       for (j,i) in graph_runner.input_replace.keys(): graph_runner.jit_cache[j].bufs[i] = None
-      graphed_jit_cache.append(ExecItem(graph_runner, cast(list[…
+      graphed_jit_cache.append(ExecItem(graph_runner, cast(list[Buffer|None], input_rawbuffers)))
       max_batch_size *= 2
-      if DEBUG >= 2: print(f"JIT GRAPHing batch with {len(current_batch)} kernels on device {…
+      if DEBUG >= 2: print(f"JIT GRAPHing batch with {len(current_batch)} kernels on device {current_batch_devs[0]}")
     except GraphException as e:
       graphed_jit_cache.extend(current_batch)
-      if DEBUG >= 2: print(f"JIT GRAPHing failed batch with {len(current_batch)} kernels on device {…
+      if DEBUG >= 2: print(f"JIT GRAPHing failed batch with {len(current_batch)} kernels on device {current_batch_devs[0]}: {e}")
     current_batch = []
-    …
+    current_batch_devs = []
 
   for ji in jit_cache:
-    … (11 removed lines not preserved in this view)
+    match ji.prg:
+      case CompiledRunner(): ji_graph_dev = ji.prg.dev
+      case BufferXfer(): ji_graph_dev = Device[unwrap(ji.bufs[0]).device]
+      case BufferCopy(): ji_graph_dev = next((Device[unwrap(b).device] for b in ji.bufs if unwrap(b).device not in {"CPU", "LLVM"}), None)
+      case ViewOp(): continue # ViewOps are just ignored
+      case _: ji_graph_dev = None # Everything else is not graphed and flushes existing graph if it's being constructed
+
+    # Check if this jit item can be graphed at all, so check if a new graph supports the current item.
+    can_be_graphed = ji_graph_dev is not None and ji_graph_dev.graph is not None and graph_class(ji_graph_dev).supports_exec_item([ji_graph_dev], ji)
+
+    # Check if the current batch can be extended with this item.
+    can_share_graph = can_be_graphed and len(current_batch_devs) > 0 and \
+      graph_class(current_batch_devs[0]).supports_exec_item(dedup(current_batch_devs + [ji_graph_dev]), ji)
+    can_extend_graph_batch = can_share_graph and (max_batch_size == 0 or len(current_batch) < max_batch_size)
+
+    # Flush the current batch if any, since it can't be extended or is full.
     if not can_extend_graph_batch and len(current_batch) > 0: flush_batch()
-    …
-    if can_be_graphed…
-    else: graphed_jit_cache.append(ji)
-    …
-    current_device = ji_graph_dev
+    (current_batch if can_be_graphed else graphed_jit_cache).append(ji)
+    current_batch_devs = dedup(current_batch_devs + [ji_graph_dev]) if can_be_graphed else []
 
   if len(current_batch) > 0: flush_batch()
   return graphed_jit_cache
@@ -72,8 +76,8 @@ class GraphRunner(Runner):
   def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]):
     self.jit_cache = jit_cache # NOTE: this is not used, but you have to keep these objects alive for the Graph
     self.input_replace:dict[tuple[int, int], int] = get_input_replace(jit_cache, input_rawbuffers)
-    self.var_vals_replace:dict[int, list[int]] = {}
-    self.launch_dims_replace:dict[int, tuple[…
+    self.var_vals_replace:dict[int, list[tuple[int, int]]] = {}
+    self.launch_dims_replace:dict[int, tuple[int|None, int|None]] = {}
     self.launch_dims_base:dict[int, tuple[tuple[int, ...], tuple[int, ...]]] = {}
 
     def is_sym_dim(dim) -> bool: return not all(isinstance(d, (int, float)) for d in dim)
@@ -87,7 +91,7 @@ class GraphRunner(Runner):
     for j,ji in enumerate(jit_cache):
       estimates += ji.prg.estimates
       if isinstance(ji.prg, CompiledRunner):
-        if ji.prg.p.vars: self.var_vals_replace[j] = [self.vars.index(v) for v in ji.prg.p.vars]
+        if ji.prg.p.vars: self.var_vals_replace[j] = [(i, self.vars.index(v)) for i, v in enumerate(ji.prg.p.vars) if v not in ji.fixedvars]
 
         global_dim_idx, local_dim_idx = find_symbolic_dim(ji.prg.p.global_size), find_symbolic_dim(ji.prg.p.local_size)
         if global_dim_idx is not None or local_dim_idx is not None:
@@ -104,7 +108,7 @@ class GraphRunner(Runner):
   def updated_vars(self, var_vals: dict[Variable, int]):
     vals = [var_vals[v] for v in self.vars]
     for j, vidxs in self.var_vals_replace.items():
-      for i, v in …
+      for i, v in vidxs: yield j, i, vals[v]
 
   def updated_launch_dims(self, var_vals: dict[Variable, int]):
     dims = [tuple(sym_infer(s, var_vals) for s in dim) for dim in self.symbolic_dims]
@@ -120,21 +124,31 @@ class GraphRunner(Runner):
       if id(rawbuf.base._buf) in self.w_dependency_map: wait_nodes.append(self.w_dependency_map[id(rawbuf.base._buf)])
       if i in write:
         if id(rawbuf.base._buf) in self.r_dependency_map: wait_nodes.extend(self.r_dependency_map.pop(id(rawbuf.base._buf)))
-        …
+
+    for i,rawbuf in enumerate(rawbufs):
+      if i in write: self.w_dependency_map[id(rawbuf.base._buf)] = new_dependency
       else: self.r_dependency_map[id(rawbuf.base._buf)].append(new_dependency)
 
     return list({id(x):x for x in wait_nodes}.values())
 
+  @staticmethod
+  def supports_exec_item(devs:list[Compiled], ei:ExecItem) -> bool: return isinstance(ei.prg, CompiledRunner) and len(dedup(devs)) == 1
+
 # a marker for your graph supporting multiple devices of the same type
-class MultiGraphRunner(GraphRunner):
+class MultiGraphRunner(GraphRunner):
+  @staticmethod
+  def supports_exec_item(devs:list[Compiled], ei:ExecItem) -> bool:
+    # Devices must be the same type
+    return isinstance(ei.prg, (CompiledRunner, BufferXfer)) and len(dedup([type(Device[b.device]) for b in ei.bufs if b]+[type(d) for d in devs]))==1
+
+def get_out_buffers_for_ei(ei:ExecItem) -> list[Buffer]:
+  if isinstance(ei.prg, CompiledRunner): return [cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs if out not in ei.prg.p.ins]
+  if isinstance(ei.prg, (BufferCopy, BufferXfer)): return [cast(Buffer, ei.bufs[0])]
+  return []
 
 def update_depends(depends:set[Buffer|None], jit_cache:list[ExecItem]):
   for ei in jit_cache:
-    if any(b in depends for b in ei.bufs):
-      if isinstance(ei.prg, CompiledRunner):
-        depends.update(cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs if out not in ei.prg.p.ins)
-      if isinstance(ei.prg, (BufferCopy, BufferXfer)):
-        depends.add(cast(Buffer, ei.bufs[0]))
+    if any(b in depends for b in ei.bufs): depends.update(get_out_buffers_for_ei(ei))
 
 ReturnType = TypeVar('ReturnType')
 @dataclass
@@ -143,11 +157,11 @@ class CapturedJit(Generic[ReturnType]):
   jit_cache: list[ExecItem]
   input_replace: dict[tuple[int, int], int]
   extra_view_inputs: list[tuple[int, int, str, int, DType]]
-  expected_names: list[…
+  expected_names: list[int|str]
   expected_st_vars_dtype_device: list[tuple[ShapeTracker, tuple[Variable, ...], DType, str]]
 
   def __reduce__(self):
-    # TODO: free_intermediates here?
+    # TODO: free_intermediates here? replan_buffers_memory_layout here?
     return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs,
                             self.expected_names, self.expected_st_vars_dtype_device)
 
@@ -164,9 +178,19 @@ class CapturedJit(Generic[ReturnType]):
     depends: set[Buffer|None] = set([None])
     update_depends(depends, self.jit_cache)
     for b in depends:
-      if b is not None: …
+      if b is not None:
+        if b.is_allocated(): b.deallocate()
+        if b._base is not None and b._base.allocated_views == 0 and b._base.is_allocated(): b._base.deallocate()
     self.__post_init__() # reset the graph state
 
+  def replan_buffers_memory_layout(self):
+    blacklist = [t.uop.buffer for t in get_parameters(self.ret)]
+    asgn = _internal_memory_planner([[b for item in self.jit_cache for b in item.bufs if b is not None and b not in blacklist]], ignore_checks=True)
+    self.jit_cache = [ExecItem(item.prg, [asgn.get(b,b) if b is not None else None for b in item.bufs]) for item in self.jit_cache]
+    for old, new in asgn.items():
+      if old.is_allocated(): new.ensure_allocated().copyin(old.as_buffer())
+    self.__post_init__()
+
   # jit exec
   def __call__(self, input_buffers:list[Buffer], var_vals:dict[Variable, int]) -> ReturnType:
     # assign inputs
@@ -182,7 +206,7 @@ class CapturedJit(Generic[ReturnType]):
           if b is not None: b.ensure_allocated()
       # create graph if needed
       if JIT < 2:
-        self._jit_cache = apply_graph_to_jit(self.jit_cache, input_buffers, var_vals, max_batch_size=…
+        self._jit_cache = apply_graph_to_jit(self.jit_cache, input_buffers, var_vals, max_batch_size=JIT_BATCH_SIZE.value)
       self._input_replace = get_input_replace(self._jit_cache, input_buffers)
       self._first_run = False
 
@@ -194,10 +218,11 @@ class CapturedJit(Generic[ReturnType]):
 def _prepare_jit_inputs(args, kwargs):
   input_tensors: list[tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor]
   names, tensors = [name for name,_ in input_tensors], [t for _,t in input_tensors]
-  if len(unrealized_tensors := [x for x in tensors if not x.…
-  # TODO: …
-  lbs: list[UOp] = flatten([t.…
-  input_buffers: list[Buffer] = […
+  if len(unrealized_tensors := [x for x in tensors if not x.uop.is_realized]): Tensor.realize(*unrealized_tensors)
+  # TODO: this multi unpack stuff is not well tested.
+  lbs: list[UOp] = flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors])
+  input_buffers: list[Buffer] = flatten([rb.bufs if isinstance(rb:=lb.base.realized, MultiBuffer) else [rb]
+                                         for lb in lbs if lb.base.realized is not None])
   assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
   st_varval_dtype_device = [(*unwrap(lb.st).unbind(), lb.dtype, lb.device) for lb in lbs]
   var_vals = merge_dicts([x[1] for x in st_varval_dtype_device] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))])
@@ -205,16 +230,17 @@ def _prepare_jit_inputs(args, kwargs):
   return input_buffers, var_vals, names, st_vars_dtype_device
 
 class TinyJit(Generic[ReturnType]):
-  def __init__(self, fxn:…
+  def __init__(self, fxn:Callable[..., ReturnType]|None, captured:CapturedJit|None=None, prune=False, optimize=False):
    assert fxn or captured, "need either a function or a CapturedJit"
     self.fxn = fxn
-    self.captured: …
+    self.captured: CapturedJit|None = captured
     self.cnt: int = 2 if self.fxn is None else 0
     self.prune = prune
+    self.optimize = optimize
 
   def add_buffer(self, b:Buffer) -> Buffer:
     if found:=self._buffer_replace.get(b, None): return found
-    if b.is_allocated() or b.…
+    if b.is_allocated() or b.uop_refcount > 0: return b
     if b._base is not None:
       self._buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, base=self.add_buffer(b._base), offset=b.offset)
     else:
@@ -222,7 +248,7 @@ class TinyJit(Generic[ReturnType]):
     return ret
 
   def add(self, ei:ExecItem):
-    self._jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.bufs if buf is not None]))
+    self._jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.bufs if buf is not None], ei.metadata, ei.fixedvars))
 
   def reset(self):
     assert self.fxn is not None, "can't reset without function"
@@ -281,8 +307,7 @@ class TinyJit(Generic[ReturnType]):
       if self.prune:
        depends = set(input_buffers)
        update_depends(depends, jit_cache)
-        pruned, onetime = partition(jit_cache,
-          lambda ei: not isinstance(ei.prg, CompiledRunner) or any(ei.bufs[out] in depends for out in ei.prg.p.outs))
+        pruned, onetime = partition(jit_cache, lambda ei: any(b in depends for b in get_out_buffers_for_ei(ei)))
        if DEBUG >= 1: print(f"pruned from {len(jit_cache)} -> {len(pruned)} kernels")
        # run the onetime kernels here
        for ei in onetime:
@@ -294,13 +319,15 @@ class TinyJit(Generic[ReturnType]):
       # Exclude buffers involved in transfer ops to preserve parallelism.
       noopt_buffers = {b for ji in jit_cache if isinstance(ji.prg, BufferXfer) for b in ji.bufs}
       assigned = _internal_memory_planner([cast(list[Buffer], item.bufs) for item in jit_cache], noopt_buffers, debug_prefix="JIT ")
-      jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None]…
+      jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None],
+                            item.metadata, item.fixedvars) for item in jit_cache]
 
       input_replace = get_input_replace(jit_cache, input_buffers)
       if DEBUG >= 1 and len(set(input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found")
 
       # set this for next run
       self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, st_vars_dtype_device)
+      if self.optimize: self.captured.replan_buffers_memory_layout()
     elif self.cnt >= 2:
       # jit exec
      assert self.captured is not None