tinygrad 0.7.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +6 -0
- tinygrad/codegen/kernel.py +572 -83
- tinygrad/codegen/linearizer.py +415 -395
- tinygrad/codegen/uops.py +415 -0
- tinygrad/device.py +183 -0
- tinygrad/dtype.py +113 -0
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +195 -0
- tinygrad/engine/realize.py +191 -0
- tinygrad/engine/schedule.py +362 -0
- tinygrad/engine/search.py +196 -0
- tinygrad/{mlops.py → function.py} +76 -55
- tinygrad/helpers.py +196 -89
- tinygrad/lazy.py +210 -371
- tinygrad/multi.py +169 -0
- tinygrad/nn/__init__.py +202 -22
- tinygrad/nn/datasets.py +7 -0
- tinygrad/nn/optim.py +112 -32
- tinygrad/nn/state.py +136 -39
- tinygrad/ops.py +119 -202
- tinygrad/renderer/__init__.py +61 -0
- tinygrad/renderer/assembly.py +276 -0
- tinygrad/renderer/cstyle.py +353 -166
- tinygrad/renderer/llvmir.py +150 -138
- tinygrad/runtime/autogen/amd_gpu.py +1900 -0
- tinygrad/runtime/autogen/comgr.py +865 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5761 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33328 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/hip_comgr.py +47 -0
- tinygrad/runtime/driver/hsa.py +143 -0
- tinygrad/runtime/graph/clang.py +38 -0
- tinygrad/runtime/graph/cuda.py +81 -0
- tinygrad/runtime/graph/hcq.py +143 -0
- tinygrad/runtime/graph/hsa.py +171 -0
- tinygrad/runtime/graph/metal.py +75 -0
- tinygrad/runtime/ops_amd.py +564 -0
- tinygrad/runtime/ops_clang.py +24 -77
- tinygrad/runtime/ops_cuda.py +175 -89
- tinygrad/runtime/ops_disk.py +56 -33
- tinygrad/runtime/ops_gpu.py +92 -95
- tinygrad/runtime/ops_hsa.py +278 -0
- tinygrad/runtime/ops_llvm.py +39 -60
- tinygrad/runtime/ops_metal.py +92 -74
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +630 -0
- tinygrad/runtime/ops_python.py +204 -0
- tinygrad/shape/shapetracker.py +86 -254
- tinygrad/shape/symbolic.py +166 -141
- tinygrad/shape/view.py +296 -0
- tinygrad/tensor.py +2619 -448
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
- tinygrad-0.9.0.dist-info/METADATA +227 -0
- tinygrad-0.9.0.dist-info/RECORD +60 -0
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/assembly.py +0 -190
- tinygrad/codegen/optimizer.py +0 -379
- tinygrad/codegen/search.py +0 -72
- tinygrad/graph.py +0 -83
- tinygrad/jit.py +0 -57
- tinygrad/nn/image.py +0 -100
- tinygrad/renderer/assembly_arm64.py +0 -169
- tinygrad/renderer/assembly_ptx.py +0 -98
- tinygrad/renderer/wgsl.py +0 -53
- tinygrad/runtime/lib.py +0 -113
- tinygrad/runtime/ops_cpu.py +0 -51
- tinygrad/runtime/ops_hip.py +0 -82
- tinygrad/runtime/ops_shm.py +0 -29
- tinygrad/runtime/ops_torch.py +0 -30
- tinygrad/runtime/ops_webgpu.py +0 -45
- tinygrad-0.7.0.dist-info/METADATA +0 -212
- tinygrad-0.7.0.dist-info/RECORD +0 -40
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
tinygrad/nn/state.py
CHANGED
@@ -1,37 +1,80 @@
|
|
1
|
-
import os, json, pathlib, zipfile, pickle
|
1
|
+
import os, json, pathlib, zipfile, pickle, tarfile, struct
|
2
2
|
from tqdm import tqdm
|
3
|
-
from typing import Dict, Union, List
|
3
|
+
from typing import Dict, Union, List, Optional, Any, Tuple
|
4
4
|
from tinygrad.tensor import Tensor
|
5
|
-
from tinygrad.
|
6
|
-
from tinygrad.
|
7
|
-
from tinygrad.
|
5
|
+
from tinygrad.dtype import dtypes
|
6
|
+
from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters
|
7
|
+
from tinygrad.shape.view import strides_for_shape
|
8
|
+
from tinygrad.multi import MultiLazyBuffer
|
8
9
|
|
9
|
-
safe_dtypes = {"
|
10
|
+
safe_dtypes = {"BOOL":dtypes.bool, "I8":dtypes.int8, "U8":dtypes.uint8, "I16":dtypes.int16, "U16":dtypes.uint16, "I32":dtypes.int, "U32":dtypes.uint,
|
11
|
+
"I64":dtypes.int64, "U64":dtypes.uint64, "F16":dtypes.float16, "BF16":dtypes.bfloat16, "F32":dtypes.float32, "F64":dtypes.float64}
|
10
12
|
inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}
|
11
13
|
|
12
|
-
def
|
14
|
+
def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]:
|
15
|
+
"""
|
16
|
+
Loads a .safetensor file from disk, returning the data, metadata length, and metadata.
|
17
|
+
"""
|
13
18
|
t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
|
14
|
-
json_len = t[0:
|
15
|
-
|
16
|
-
|
19
|
+
json_len = t[0:8].bitcast(dtypes.int64).item()
|
20
|
+
return t, json_len, json.loads(t[8:8+json_len].numpy().tobytes())
|
21
|
+
|
22
|
+
def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
|
23
|
+
"""
|
24
|
+
Loads a .safetensor file from disk, returning the state_dict.
|
17
25
|
|
18
|
-
|
19
|
-
|
26
|
+
```python
|
27
|
+
state_dict = nn.state.safe_load("test.safetensor")
|
28
|
+
```
|
29
|
+
"""
|
30
|
+
t, json_len, metadata = safe_load_metadata(fn)
|
31
|
+
ret = {}
|
32
|
+
for k,v in metadata.items():
|
33
|
+
if k == "__metadata__": continue
|
34
|
+
dtype = safe_dtypes[v['dtype']]
|
35
|
+
sz = (v['data_offsets'][1]-v['data_offsets'][0])
|
36
|
+
ret[k] = t[8+json_len+v['data_offsets'][0]:8+json_len+v['data_offsets'][0]+sz].bitcast(dtype).reshape(v['shape'])
|
37
|
+
return ret
|
38
|
+
|
39
|
+
def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any]]=None):
|
40
|
+
"""
|
41
|
+
Saves a state_dict to disk in a .safetensor file with optional metadata.
|
42
|
+
|
43
|
+
```python
|
44
|
+
t = nn.Tensor([1, 2, 3])
|
45
|
+
nn.state.safe_save({'t':t}, "test.safetensor")
|
46
|
+
```
|
47
|
+
"""
|
48
|
+
headers, offset = {}, 0
|
49
|
+
if metadata: headers['__metadata__'] = metadata
|
20
50
|
for k,v in tensors.items():
|
21
|
-
|
51
|
+
headers[k] = {'dtype': inverse_safe_dtypes[v.dtype], 'shape': list(v.shape), 'data_offsets':[offset, offset+v.nbytes()]}
|
22
52
|
offset += v.nbytes()
|
23
|
-
j = json.dumps(
|
53
|
+
j = json.dumps(headers, separators=(',', ':'))
|
24
54
|
j += "\x20"*((8-len(j)%8)%8)
|
25
55
|
pathlib.Path(fn).unlink(missing_ok=True)
|
26
56
|
t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}")
|
27
|
-
t[0:
|
28
|
-
t[8:8+len(j)].assign(
|
57
|
+
t[0:8].bitcast(dtypes.int64).assign([len(j)])
|
58
|
+
t[8:8+len(j)].assign(list(j.encode('utf-8')))
|
29
59
|
for k,v in safe_load(t).items(): v.assign(tensors[k])
|
30
60
|
|
31
61
|
# state dict
|
32
62
|
|
33
63
|
from collections import OrderedDict
|
34
64
|
def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> Dict[str, Tensor]:
|
65
|
+
"""
|
66
|
+
Returns a state_dict of the object, with optional prefix.
|
67
|
+
|
68
|
+
```python exec="true" source="above" session="tensor" result="python"
|
69
|
+
class Net:
|
70
|
+
def __init__(self):
|
71
|
+
self.l1 = nn.Linear(4, 5)
|
72
|
+
self.l2 = nn.Linear(5, 6)
|
73
|
+
|
74
|
+
net = Net()
|
75
|
+
print(nn.state.get_state_dict(net).keys())
|
76
|
+
```
|
77
|
+
"""
|
35
78
|
if isinstance(obj, tensor_type): return {prefix.strip('.'):obj}
|
36
79
|
if hasattr(obj, '_asdict'): return get_state_dict(obj._asdict(), prefix, tensor_type) # namedtuple
|
37
80
|
if isinstance(obj, OrderedDict): return get_state_dict(dict(obj), prefix, tensor_type)
|
@@ -42,39 +85,71 @@ def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> Dict[str, Tensor]:
|
|
42
85
|
elif isinstance(obj, dict):
|
43
86
|
for k,v in obj.items(): state_dict.update(get_state_dict(v, f"{prefix}{str(k)}.", tensor_type))
|
44
87
|
return state_dict
|
45
|
-
def get_parameters(obj) -> List[Tensor]:
|
88
|
+
def get_parameters(obj) -> List[Tensor]:
|
89
|
+
"""
|
90
|
+
```python exec="true" source="above" session="tensor" result="python"
|
91
|
+
class Net:
|
92
|
+
def __init__(self):
|
93
|
+
self.l1 = nn.Linear(4, 5)
|
94
|
+
self.l2 = nn.Linear(5, 6)
|
95
|
+
|
96
|
+
net = Net()
|
97
|
+
print(len(nn.state.get_parameters(net)))
|
98
|
+
```
|
99
|
+
"""
|
100
|
+
return list(get_state_dict(obj).values())
|
101
|
+
|
102
|
+
def load_state_dict(model, state_dict:Dict[str, Tensor], strict=True, verbose=True, consume=False) -> None:
|
103
|
+
"""
|
104
|
+
Loads a state_dict into a model.
|
46
105
|
|
47
|
-
|
48
|
-
|
106
|
+
```python
|
107
|
+
class Net:
|
108
|
+
def __init__(self):
|
109
|
+
self.l1 = nn.Linear(4, 5)
|
110
|
+
self.l2 = nn.Linear(5, 6)
|
111
|
+
|
112
|
+
net = Net()
|
113
|
+
state_dict = nn.state.get_state_dict(net)
|
114
|
+
nn.state.load_state_dict(net, state_dict)
|
115
|
+
```
|
116
|
+
"""
|
117
|
+
start_mem_used = GlobalCounters.mem_used
|
118
|
+
with Timing("loaded weights in ", lambda et_ns: f", {(GlobalCounters.mem_used-start_mem_used)/1e9:.2f} GB loaded at {(GlobalCounters.mem_used-start_mem_used)/et_ns:.2f} GB/s"): # noqa: E501
|
49
119
|
model_state_dict = get_state_dict(model)
|
50
|
-
if DEBUG >= 1 and len(state_dict) > len(model_state_dict):
|
51
|
-
|
120
|
+
if DEBUG >= 1 and len(state_dict) > len(model_state_dict):
|
121
|
+
print("WARNING: unused weights in state_dict", sorted(list(state_dict.keys() - model_state_dict.keys())))
|
122
|
+
for k,v in (t := tqdm(model_state_dict.items(), disable=CI or not verbose)):
|
52
123
|
t.set_description(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, {k:50s}")
|
53
124
|
if k not in state_dict and not strict:
|
54
125
|
if DEBUG >= 1: print(f"WARNING: not loading {k}")
|
55
126
|
continue
|
56
|
-
|
127
|
+
if isinstance((mlb:=v.lazydata), MultiLazyBuffer):
|
128
|
+
if isinstance(state_dict[k].lazydata, MultiLazyBuffer): v.replace(state_dict[k]).realize()
|
129
|
+
else: v.replace(state_dict[k].shard(mlb.device, mlb.axis)).realize()
|
130
|
+
else: v.replace(state_dict[k].to(v.device)).realize()
|
131
|
+
if consume: del state_dict[k]
|
57
132
|
|
58
133
|
# torch support!
|
59
134
|
|
60
|
-
def torch_load(fn:str):
|
135
|
+
def torch_load(fn:str) -> Dict[str, Tensor]:
|
136
|
+
"""
|
137
|
+
Loads a torch .pth file from disk.
|
138
|
+
|
139
|
+
```python
|
140
|
+
state_dict = nn.state.torch_load("test.pth")
|
141
|
+
```
|
142
|
+
"""
|
61
143
|
t = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
|
62
144
|
|
63
|
-
offsets: Dict[str, int] = {}
|
64
|
-
lens: Dict[str, int] = {}
|
65
|
-
def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
|
145
|
+
offsets: Dict[Union[str, int], int] = {}
|
146
|
+
lens: Dict[Union[str, int], int] = {}
|
147
|
+
def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad=None, backward_hooks=None, metadata=None):
|
66
148
|
#print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
|
67
149
|
lens[storage[2]] = storage[4] * storage[1].itemsize
|
68
150
|
if storage[2] not in offsets: return None
|
69
151
|
byte_offset = offsets[storage[2]]+storage_offset*storage[1].itemsize
|
70
|
-
ret = t[byte_offset:byte_offset+prod(size)].
|
71
|
-
# convert bfloat16 -> float16 using LLVM for Llama 2
|
72
|
-
# upstream LLaMA also does this conversion:
|
73
|
-
# https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L95
|
74
|
-
# TODO: should this be done in the example instead? or maybe we don't need this anymore with better bfloat16 support
|
75
|
-
if storage[1] == dtypes.bfloat16:
|
76
|
-
ret = ret.bitcast(dtypes.uint16).to("CPU").cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).to(Device.DEFAULT).half()
|
77
|
-
#ret = ret.to("LLVM").half().to(Device.DEFAULT)
|
152
|
+
ret = t[byte_offset:byte_offset+prod(size)*storage[1].itemsize].bitcast(storage[1])
|
78
153
|
|
79
154
|
# 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
|
80
155
|
shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]
|
@@ -82,13 +157,20 @@ def torch_load(fn:str):
|
|
82
157
|
if tuple(permute_indexes) != tuple(range(len(permute_indexes))):
|
83
158
|
intermediate_shape = tuple([shape_strides[x][0] for x in argsort(permute_indexes)])
|
84
159
|
assert tuple([shape_strides[i][1] for i in argsort(permute_indexes)]) == strides_for_shape(intermediate_shape), "nonpermutable strides"
|
85
|
-
if DEBUG >=
|
160
|
+
if DEBUG >= 3: print(f"WARNING: this torch load is slow. CLANG to permute {intermediate_shape} with {permute_indexes}")
|
161
|
+
assert storage[1] != dtypes.bfloat16, "can't CLANG permute BF16"
|
86
162
|
# TODO: find a nice way to support all shapetracker on disktensors
|
87
|
-
|
163
|
+
# TODO: BUG: a ".realize()" is needed here for 'GPU=1 python3 test/models/test_efficientnet.py TestEfficientNet.test_car'
|
164
|
+
ret = ret.clang().reshape(intermediate_shape).permute(permute_indexes).realize()
|
88
165
|
|
89
166
|
return ret.reshape(size)
|
90
167
|
|
91
|
-
|
168
|
+
class Parameter:
|
169
|
+
def __setstate__(self, state): self.tensor = state[0]
|
170
|
+
|
171
|
+
deserialized_objects: Dict[str, Any] = {}
|
172
|
+
intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16, "IntStorage": dtypes.int32,
|
173
|
+
"LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2, "FloatTensor": None, "Parameter": Parameter}
|
92
174
|
whitelist = {"torch", "collections", "numpy", "_codecs"} # NOTE: this is not for security, only speed
|
93
175
|
class Dummy: pass
|
94
176
|
class TorchPickle(pickle.Unpickler):
|
@@ -98,9 +180,9 @@ def torch_load(fn:str):
|
|
98
180
|
if DEBUG >= 2: print(f"WARNING: returning Dummy for {module} {name}")
|
99
181
|
return Dummy
|
100
182
|
return intercept[name] if module_root == "torch" else super().find_class(module, name)
|
101
|
-
def persistent_load(self, pid): return pid
|
183
|
+
def persistent_load(self, pid): return deserialized_objects.get(pid, pid)
|
102
184
|
|
103
|
-
if
|
185
|
+
if zipfile.is_zipfile(fn):
|
104
186
|
myzip = zipfile.ZipFile(fn, 'r')
|
105
187
|
base_name = myzip.namelist()[0].split('/', 1)[0]
|
106
188
|
for n in myzip.namelist():
|
@@ -109,6 +191,21 @@ def torch_load(fn:str):
|
|
109
191
|
offsets[n.split("/")[-1]] = myfile._orig_compress_start # type: ignore
|
110
192
|
with myzip.open(f'{base_name}/data.pkl') as myfile:
|
111
193
|
return TorchPickle(myfile).load()
|
194
|
+
elif tarfile.is_tarfile(fn):
|
195
|
+
with tarfile.open(fn, "r") as tar:
|
196
|
+
storages_offset = tar.getmember('storages').offset_data
|
197
|
+
f = unwrap(tar.extractfile('storages'))
|
198
|
+
for i in range(TorchPickle(f).load()): # num_storages
|
199
|
+
(key, _, storage_type), sz = TorchPickle(f).load(), struct.unpack('<q', f.read(8))[0]
|
200
|
+
offsets[key] = storages_offset + f.tell()
|
201
|
+
f.seek(sz*storage_type.itemsize, 1)
|
202
|
+
f = unwrap(tar.extractfile('tensors'))
|
203
|
+
for _ in range(TorchPickle(f).load()): # num_tensors
|
204
|
+
(key, storage_id, _), ndim, _ = TorchPickle(f).load(), struct.unpack('<i', f.read(4))[0], f.read(4)
|
205
|
+
size, stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim)), struct.unpack(f'<{ndim}q', f.read(8 * ndim))
|
206
|
+
storage_offset = struct.unpack('<q', f.read(8))[0]
|
207
|
+
deserialized_objects[str(key)] = _rebuild_tensor_v2((None, storage_type, storage_id, None, -1), storage_offset, size, stride)
|
208
|
+
return {k:v.tensor if isinstance(v, Parameter) else v for k,v in TorchPickle(unwrap(tar.extractfile('pickle'))).load().items()}
|
112
209
|
else:
|
113
210
|
with open(fn, "rb") as f:
|
114
211
|
pkl = TorchPickle(f)
|
tinygrad/ops.py
CHANGED
@@ -1,219 +1,136 @@
|
|
1
1
|
from __future__ import annotations
|
2
|
-
import
|
2
|
+
from typing import Union, Tuple, Any, List, Dict, Callable
|
3
|
+
import functools, hashlib, math, operator, ctypes
|
3
4
|
from enum import Enum, auto
|
4
|
-
from
|
5
|
-
from tinygrad.helpers import
|
6
|
-
|
5
|
+
from dataclasses import dataclass
|
6
|
+
from tinygrad.helpers import prod, dedup
|
7
|
+
from tinygrad.dtype import dtypes, DType, ConstType
|
8
|
+
from tinygrad.shape.symbolic import Variable, sint
|
9
|
+
from tinygrad.shape.shapetracker import ShapeTracker
|
7
10
|
|
8
11
|
# these are the llops your accelerator must implement, along with toCpu
|
9
12
|
# the Enum class doesn't work with mypy, this is static. sorry it's ugly
|
10
13
|
# NOTE: MOD, CMPLT don't have to be implemented on vectors, just scalars
|
11
|
-
# NOTE:
|
12
|
-
class UnaryOps(Enum):
|
13
|
-
|
14
|
-
|
15
|
-
class
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
14
|
+
# NOTE: many GPUs don't have DIV, but UnaryOps.RECIP doesn't work for integer division
|
15
|
+
class UnaryOps(Enum):
|
16
|
+
"""A -> A (elementwise)"""
|
17
|
+
EXP2 = auto(); LOG2 = auto(); CAST = auto(); BITCAST = auto(); SIN = auto(); SQRT = auto(); NEG = auto() # noqa: E702
|
18
|
+
class BinaryOps(Enum):
|
19
|
+
"""A + A -> A (elementwise)"""
|
20
|
+
ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPEQ = auto(); XOR = auto() # noqa: E702
|
21
|
+
class TernaryOps(Enum):
|
22
|
+
"""A + A + A -> A (elementwise)"""
|
23
|
+
WHERE = auto(); MULACC = auto() # noqa: E702
|
24
|
+
class ReduceOps(Enum):
|
25
|
+
"""A -> B (reduce)"""
|
26
|
+
SUM = auto(); MAX = auto() # noqa: E702
|
27
|
+
class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702
|
28
|
+
class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); ASSIGN = auto(); VIEW = auto() # noqa: E702
|
29
|
+
|
30
|
+
Op = Union[UnaryOps, BinaryOps, ReduceOps, LoadOps, TernaryOps, BufferOps]
|
31
|
+
|
32
|
+
# do not preserve f(0) = 0
|
33
|
+
UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, BinaryOps.CMPEQ, UnaryOps.LOG2, UnaryOps.EXP2}
|
34
|
+
|
35
|
+
@dataclass(frozen=True)
|
36
|
+
class MemBuffer:
|
37
|
+
idx: int
|
38
|
+
dtype: DType
|
39
|
+
st: ShapeTracker
|
40
|
+
|
41
|
+
@dataclass(frozen=True)
|
42
|
+
class ConstBuffer:
|
43
|
+
val: ConstType
|
44
|
+
dtype: DType
|
45
|
+
st: ShapeTracker
|
46
|
+
|
47
|
+
@dataclass(frozen=True, eq=False)
|
22
48
|
class LazyOp:
|
23
|
-
__slots__ = "op", "src", "arg", "buffers", "__weakref__"
|
24
49
|
op: Op
|
25
|
-
src: Tuple[
|
26
|
-
arg: Any
|
27
|
-
|
28
|
-
|
29
|
-
self.op
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
50
|
+
src: Tuple[LazyOp, ...] = ()
|
51
|
+
arg: Any = None
|
52
|
+
def cached_compare(self, x, context):
|
53
|
+
if id(self) == id(x): return True
|
54
|
+
if self.op != x.op or self.arg != x.arg or len(self.src) != len(x.src): return False
|
55
|
+
if (key := (id(self), id(x))) in context: return context[key]
|
56
|
+
ret = context[key] = all(a.cached_compare(b, context) for a,b in zip(self.src, x.src))
|
57
|
+
return ret
|
58
|
+
def __eq__(self, x): return self.cached_compare(x, context={})
|
34
59
|
def __repr__(self): return f"LazyOp(op={self.op}, src={self.src}, arg={self.arg})"
|
35
|
-
|
36
|
-
def
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
# Any == Union[LazyBuffer, DeviceBuffer]
|
41
|
-
def map_buffers(self, real_srcs: Dict[Any, Any]) -> LazyOp: return LazyOp(self.op, tuple([y.map_buffers(real_srcs) for y in self.src]), self.arg)
|
42
|
-
def get_lazyops(self) -> List[LazyOp]: return [self] + [item for x in self.src for item in x.get_lazyops()]
|
43
|
-
|
44
|
-
def replace_with_movement_ops(self:LazyOp, ops:List[Tuple[MovementOps, Tuple[Any, ...]]]) -> 'LazyBuffer':
|
45
|
-
assert self.op in BinaryOps or self.op in UnaryOps or self.op in TernaryOps
|
46
|
-
srcs = [z.replace_with_movement_ops(ops) for z in self.src]
|
47
|
-
return srcs[0].e(self.op, *srcs[1:], arg=self.arg) # type: ignore
|
60
|
+
@functools.cached_property
|
61
|
+
def dtype(self) -> DType:
|
62
|
+
if self.op in BufferOps: return self.arg.dtype
|
63
|
+
if self.op in [UnaryOps.CAST, UnaryOps.BITCAST]: return self.arg
|
64
|
+
return dtypes.bool if self.op in {BinaryOps.CMPLT, BinaryOps.CMPEQ} else self.src[-1].dtype
|
48
65
|
|
49
|
-
@property
|
50
|
-
def st(self): raise NotImplementedError
|
51
|
-
@property
|
52
|
-
def children(self): raise NotImplementedError
|
53
|
-
@property
|
54
|
-
def shape(self): raise NotImplementedError
|
55
|
-
@property
|
56
|
-
def realized(self): raise NotImplementedError
|
57
|
-
@property
|
58
|
-
def optype(self): raise NotImplementedError
|
59
|
-
def realize(self): raise NotImplementedError
|
60
|
-
|
61
|
-
# movement ops
|
62
|
-
def reshape(self, _): raise NotImplementedError
|
63
|
-
def pad(self, _): raise NotImplementedError
|
64
|
-
def expand(self, _): raise NotImplementedError
|
65
|
-
def permute(self, _): raise NotImplementedError
|
66
|
-
def shrink(self, _): raise NotImplementedError
|
67
|
-
def stride(self, _): raise NotImplementedError
|
68
|
-
|
69
|
-
# **************** Device ****************
|
70
|
-
|
71
|
-
class _Device:
|
72
|
-
def __init__(self) -> None: self._buffers: List[str] = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
|
73
|
-
def canonicalize(self, device:Optional[str]) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") if device is not None else self.DEFAULT
|
74
|
-
@functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
|
75
|
-
def __getitem__(self, x:str) -> Union[Interpreted, Compiled]:
|
76
|
-
x = x.split(":")[0].upper()
|
77
|
-
return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "buffer") and x in self._buffers][0]
|
78
66
|
@functools.cached_property
|
79
|
-
def
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
class Interpreted:
|
92
|
-
def __init__(self, buffer, fxn_for_op: Dict[Op, Callable], from_lazybuffer=lambda x: x.realized, to_underlying=lambda x: x._buf, from_underlying=None):
|
93
|
-
self.buffer, self.fxn_for_op, self.from_lazybuffer, self.to_underlying = buffer, fxn_for_op, from_lazybuffer, to_underlying
|
94
|
-
self.from_underlying = buffer if from_underlying is None else from_underlying
|
95
|
-
self.synchronize = lambda: None
|
96
|
-
self.codegen = None
|
97
|
-
|
98
|
-
def exec_ast(self, ast:LazyOp, output=None, context=None, **kwargs):
|
99
|
-
if TernaryOps.MULACC in self.fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
|
100
|
-
ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
|
101
|
-
created_context = context is None
|
102
|
-
if context is None: context = dict()
|
103
|
-
if not created_context and ast in context: return context[ast]
|
104
|
-
srcs = [self.exec_ast(x, context=context, **kwargs) if isinstance(x, LazyOp) else self.from_lazybuffer(x) for x in ast.src]
|
105
|
-
if DEBUG >= 3: st = time.perf_counter()
|
106
|
-
ret = self.from_underlying(self.fxn_for_op[ast.op](*([self.to_underlying(x) for x in srcs] + ([ast.arg] if ast.arg is not None else []))))
|
107
|
-
if output is not None and ret.dtype != output.dtype and UnaryOps.CAST in self.fxn_for_op: ret = self.from_underlying(self.fxn_for_op[UnaryOps.CAST](self.to_underlying(ret), (output.dtype, False))) # Do manual casting of ret if it does not match the required output dtype.
|
108
|
-
if DEBUG >= 3: print(f"*** {'exec' if created_context else ' '} {GlobalCounters.mem_used/1e9:5.2f} GB {(time.perf_counter()-st)*1e3:7.2f} ms op: {ast.op:20s} out({ret.dtype.name}): {str(ret._buf.shape) if hasattr(ret._buf, 'shape') else str(len(ret._buf)):30s} in({len(srcs)}):", list(set(x._buf.shape if hasattr(x._buf, 'shape') else len(x._buf) for x in srcs)), ast.arg if ast.arg is not None else "")
|
109
|
-
if not created_context: context[ast] = ret
|
110
|
-
if output is not None and output.output_buffer is not None:
|
111
|
-
assert output.output_buffer.size == ret.size, output.output_buffer.dtype == ret.dtype
|
112
|
-
output.output_buffer._buf = ret._buf
|
113
|
-
return output.output_buffer
|
114
|
-
return ret
|
67
|
+
def key(self) -> bytes:
|
68
|
+
return hashlib.sha256(functools.reduce(lambda x,y: x+y, [s.key for s in self.src], str((self.op, self.arg)).encode())).digest()
|
69
|
+
@functools.cached_property
|
70
|
+
def hash(self): return hash((self.op, self.src, self.arg))
|
71
|
+
def __hash__(self): return self.hash
|
72
|
+
@functools.cached_property
|
73
|
+
def lazyops(self) -> List[LazyOp]: return dedup([self] + [item for x in self.src for item in x.lazyops])
|
74
|
+
def vars(self) -> List[Variable]:
|
75
|
+
extract_vars = [x.arg.st.vars() for x in self.lazyops if x.op in BufferOps]
|
76
|
+
const_vars = [x.arg.val.unbind()[0] for x in self.lazyops if x.op is BufferOps.CONST and isinstance(x.arg.val, Variable)]
|
77
|
+
return sorted(set.union(*extract_vars, set(const_vars)), key=lambda x: str(x.expr))
|
115
78
|
|
116
|
-
#
|
79
|
+
# **************** independent FlopCounter ****************
|
117
80
|
|
81
|
+
@dataclass
|
118
82
|
class FlopCounter:
|
119
|
-
|
83
|
+
shape: Tuple[int, ...]
|
84
|
+
flops: sint
|
85
|
+
mem: Dict[int, int]
|
86
|
+
@property
|
87
|
+
def mem_estimate(self): return sum(self.mem.values())
|
120
88
|
def consume_flops(self):
|
121
89
|
self.flops, ret = 0, self.flops
|
122
90
|
return ret
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
def
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
self.buffer, self.linearizer_opts, self.renderer, self.runtime, self.synchronize = buffer, linearizer_opts, renderer, runtime, synchronize
|
170
|
-
self.method_cache: Dict[Any, ASTRunner] = {}
|
171
|
-
|
172
|
-
def to_program(self, k):
|
173
|
-
k.linearize()
|
174
|
-
ret = self.renderer(k.function_name, k.uops)
|
175
|
-
src, global_size, local_size, binary = ret if len(ret) == 4 else ret + (False,)
|
176
|
-
return ASTRunner(k.function_name, src, global_size, local_size,
|
177
|
-
op_estimate=k.info.flops, mem_estimate=k.mem_estimate,
|
178
|
-
display_name=k.display_name, runtime_args={"binary": binary}).build(self.runtime)
|
179
|
-
|
180
|
-
def exec_ast(self, ast:LazyOp, output, **kwargs):
|
181
|
-
# all movementops do nothing in a Compiled buffer!
|
182
|
-
if ast.op in MovementOps and ast.src[0].__class__ is not LazyOp and ast.src[0].realized: return ast.src[0].realized
|
183
|
-
|
184
|
-
# check if we can reuse the output buffer
|
185
|
-
# if it's aliased, don't use it
|
186
|
-
# NOTE: this is pretty wrong actually, who knows where else this buffer is used?
|
187
|
-
output.realized = output.output_buffer
|
188
|
-
if output.realized:
|
189
|
-
if output.realized.__class__ is RawConst: output.realized = None # can't assign to RawConst
|
190
|
-
for a in ast.buffers:
|
191
|
-
if a.realized == output.realized and not a.st.contiguous:
|
192
|
-
output.realized = None
|
193
|
-
break
|
194
|
-
|
195
|
-
# we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
|
196
|
-
if not output.realized: output.realized = self.buffer(prod((s if isinstance(s, int) else s.max for s in output.shape)), output.dtype, **kwargs)
|
197
|
-
# update the output var_vals from src
|
198
|
-
output.st.var_vals = dict(sorted(merge_dicts([buf.st.var_vals for buf in ast.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key))
|
199
|
-
|
200
|
-
from tinygrad.codegen.linearizer import Linearizer
|
201
|
-
k = Linearizer(ast, output, self.linearizer_opts)
|
202
|
-
|
203
|
-
# compilation time
|
204
|
-
def get_program():
|
205
|
-
from tinygrad.codegen.search import kernel_optimize
|
206
|
-
if getenv("KOPT"): kernel_optimize(k, lambda: Linearizer(ast, output, self.linearizer_opts), self.to_program)
|
207
|
-
elif not getenv("NOOPT"): k.hand_coded_optimizations()
|
208
|
-
return self.to_program(k)
|
209
|
-
|
210
|
-
if hasattr(k, 'key') and getenv("ENABLE_METHOD_CACHE", 1):
|
211
|
-
if k.key not in self.method_cache: self.method_cache[k.key] = get_program()
|
212
|
-
prg = self.method_cache[k.key]
|
213
|
-
else:
|
214
|
-
prg = get_program()
|
215
|
-
|
216
|
-
if prg.name == getenv("PRINT_PRG", ''): print(prg.prg)
|
217
|
-
|
218
|
-
prg.exec(k.bufs, var_vals=output.st.var_vals)
|
219
|
-
return output.realized
|
91
|
+
|
92
|
+
InterpretedFlopCounter: Dict[Op, Callable] = {
|
93
|
+
BufferOps.LOAD: lambda arg: FlopCounter(arg.st.shape, 0, {arg.idx: arg.dtype.itemsize * arg.st.real_size()}),
|
94
|
+
BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, 0, {}),
|
95
|
+
BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize * arg.st.real_size()}),
|
96
|
+
UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, self.consume_flops(), self.mem), # cast uses no flops
|
97
|
+
UnaryOps.BITCAST: lambda self,arg: FlopCounter(self.shape, self.consume_flops(), self.mem), # bitcast uses no flops
|
98
|
+
**{op:lambda self: FlopCounter(self.shape, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op not in {UnaryOps.CAST, UnaryOps.BITCAST}}, # noqa: E501
|
99
|
+
**{op:lambda self,y: FlopCounter(self.shape, self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, # noqa: E501
|
100
|
+
**{op:lambda self,axis: FlopCounter(tuple(1 if i in axis else s for i,s in enumerate(self.shape)), self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps}, # noqa: E501
|
101
|
+
TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})} # noqa: E501
|
102
|
+
|
103
|
+
@functools.lru_cache(None)
|
104
|
+
def get_lazyop_info(ast:LazyOp) -> FlopCounter:
|
105
|
+
@functools.lru_cache(None) # NOTE: this cache needs to be recreated for new ASTs
|
106
|
+
def run_ast(ast): return InterpretedFlopCounter[ast.op](*([run_ast(x) for x in ast.src]+([ast.arg] if ast.arg is not None else [])))
|
107
|
+
return run_ast(ast)
|
108
|
+
|
109
|
+
# **************** ops in python ****************
|
110
|
+
|
111
|
+
def hook_overflow(dv, fxn):
|
112
|
+
def wfxn(*args):
|
113
|
+
try: return fxn(*args)
|
114
|
+
except OverflowError: return dv
|
115
|
+
return wfxn
|
116
|
+
|
117
|
+
python_alu = {
|
118
|
+
UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan,
|
119
|
+
UnaryOps.EXP2: hook_overflow(math.inf, lambda x: math.exp(x*math.log(2))),
|
120
|
+
UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.SIN: math.sin,
|
121
|
+
UnaryOps.NEG: lambda x: (not x) if isinstance(x, bool) else -x,
|
122
|
+
BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.XOR: operator.xor,
|
123
|
+
BinaryOps.MAX: max, BinaryOps.CMPEQ: operator.eq, BinaryOps.CMPLT: operator.lt,
|
124
|
+
BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0],
|
125
|
+
BinaryOps.DIV: lambda x,y: int(x/y) if isinstance(x, int) else (x/y if y != 0 else x*math.inf),
|
126
|
+
TernaryOps.WHERE: lambda x,y,z: y if x else z}
|
127
|
+
|
128
|
+
truncate: Dict[DType, Callable] = {dtypes.bool: bool,
|
129
|
+
# TODO: float16 and bfloat16?
|
130
|
+
dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
|
131
|
+
dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
|
132
|
+
dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
|
133
|
+
dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value,
|
134
|
+
dtypes.int32: lambda x: ctypes.c_int32(x).value, dtypes.int64: lambda x: ctypes.c_int64(x).value,}
|
135
|
+
|
136
|
+
def exec_alu(op:Op, dtype:DType, operands): return truncate.get(dtype, lambda x: x)(python_alu[op](*operands))
|