tinygrad-0.7.0-py3-none-any.whl → tinygrad-0.9.0-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (77)
  1. tinygrad/__init__.py +6 -0
  2. tinygrad/codegen/kernel.py +572 -83
  3. tinygrad/codegen/linearizer.py +415 -395
  4. tinygrad/codegen/uops.py +415 -0
  5. tinygrad/device.py +183 -0
  6. tinygrad/dtype.py +113 -0
  7. tinygrad/engine/__init__.py +0 -0
  8. tinygrad/engine/graph.py +100 -0
  9. tinygrad/engine/jit.py +195 -0
  10. tinygrad/engine/realize.py +191 -0
  11. tinygrad/engine/schedule.py +362 -0
  12. tinygrad/engine/search.py +196 -0
  13. tinygrad/{mlops.py → function.py} +76 -55
  14. tinygrad/helpers.py +196 -89
  15. tinygrad/lazy.py +210 -371
  16. tinygrad/multi.py +169 -0
  17. tinygrad/nn/__init__.py +202 -22
  18. tinygrad/nn/datasets.py +7 -0
  19. tinygrad/nn/optim.py +112 -32
  20. tinygrad/nn/state.py +136 -39
  21. tinygrad/ops.py +119 -202
  22. tinygrad/renderer/__init__.py +61 -0
  23. tinygrad/renderer/assembly.py +276 -0
  24. tinygrad/renderer/cstyle.py +353 -166
  25. tinygrad/renderer/llvmir.py +150 -138
  26. tinygrad/runtime/autogen/amd_gpu.py +1900 -0
  27. tinygrad/runtime/autogen/comgr.py +865 -0
  28. tinygrad/runtime/autogen/cuda.py +5923 -0
  29. tinygrad/runtime/autogen/hip.py +5909 -0
  30. tinygrad/runtime/autogen/hsa.py +5761 -0
  31. tinygrad/runtime/autogen/kfd.py +812 -0
  32. tinygrad/runtime/autogen/nv_gpu.py +33328 -0
  33. tinygrad/runtime/autogen/opencl.py +1795 -0
  34. tinygrad/runtime/driver/hip_comgr.py +47 -0
  35. tinygrad/runtime/driver/hsa.py +143 -0
  36. tinygrad/runtime/graph/clang.py +38 -0
  37. tinygrad/runtime/graph/cuda.py +81 -0
  38. tinygrad/runtime/graph/hcq.py +143 -0
  39. tinygrad/runtime/graph/hsa.py +171 -0
  40. tinygrad/runtime/graph/metal.py +75 -0
  41. tinygrad/runtime/ops_amd.py +564 -0
  42. tinygrad/runtime/ops_clang.py +24 -77
  43. tinygrad/runtime/ops_cuda.py +175 -89
  44. tinygrad/runtime/ops_disk.py +56 -33
  45. tinygrad/runtime/ops_gpu.py +92 -95
  46. tinygrad/runtime/ops_hsa.py +278 -0
  47. tinygrad/runtime/ops_llvm.py +39 -60
  48. tinygrad/runtime/ops_metal.py +92 -74
  49. tinygrad/runtime/ops_npy.py +9 -0
  50. tinygrad/runtime/ops_nv.py +630 -0
  51. tinygrad/runtime/ops_python.py +204 -0
  52. tinygrad/shape/shapetracker.py +86 -254
  53. tinygrad/shape/symbolic.py +166 -141
  54. tinygrad/shape/view.py +296 -0
  55. tinygrad/tensor.py +2619 -448
  56. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
  57. tinygrad-0.9.0.dist-info/METADATA +227 -0
  58. tinygrad-0.9.0.dist-info/RECORD +60 -0
  59. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
  60. tinygrad/codegen/assembly.py +0 -190
  61. tinygrad/codegen/optimizer.py +0 -379
  62. tinygrad/codegen/search.py +0 -72
  63. tinygrad/graph.py +0 -83
  64. tinygrad/jit.py +0 -57
  65. tinygrad/nn/image.py +0 -100
  66. tinygrad/renderer/assembly_arm64.py +0 -169
  67. tinygrad/renderer/assembly_ptx.py +0 -98
  68. tinygrad/renderer/wgsl.py +0 -53
  69. tinygrad/runtime/lib.py +0 -113
  70. tinygrad/runtime/ops_cpu.py +0 -51
  71. tinygrad/runtime/ops_hip.py +0 -82
  72. tinygrad/runtime/ops_shm.py +0 -29
  73. tinygrad/runtime/ops_torch.py +0 -30
  74. tinygrad/runtime/ops_webgpu.py +0 -45
  75. tinygrad-0.7.0.dist-info/METADATA +0 -212
  76. tinygrad-0.7.0.dist-info/RECORD +0 -40
  77. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
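Beyond the per-file churn, the list above shows a restructuring: scheduling, JIT, graphing, and search moved into the new tinygrad/engine/ package, mlops.py was renamed to function.py, dtypes split out of helpers.py into the new dtype.py, Device moved out of ops.py into the new device.py, and the old interpreted backends (ops_cpu.py, ops_torch.py, runtime/lib.py) were deleted in favor of autogenerated runtime bindings. A hedged sketch of what that means for imports; the 0.9.0 paths are inferred from the file list above and the diffs that follow:

```python
# Sketch of import moves between these two wheels; the 0.9.0 locations are
# inferred from the file list (dtype.py and device.py are new files there).
# tinygrad 0.7.0:
#   from tinygrad.helpers import dtypes   # dtypes lived in helpers.py
#   from tinygrad.ops import Device       # Device lived in ops.py
# tinygrad 0.9.0:
from tinygrad.dtype import dtypes    # new tinygrad/dtype.py
from tinygrad.device import Device   # new tinygrad/device.py

print(dtypes.float32, Device.DEFAULT)
```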
tinygrad/nn/state.py CHANGED
@@ -1,37 +1,80 @@
-import os, json, pathlib, zipfile, pickle
+import os, json, pathlib, zipfile, pickle, tarfile, struct
 from tqdm import tqdm
-from typing import Dict, Union, List
+from typing import Dict, Union, List, Optional, Any, Tuple
 from tinygrad.tensor import Tensor
-from tinygrad.helpers import dtypes, prod, argsort, DEBUG, Timing, GlobalCounters, CI
-from tinygrad.shape.shapetracker import strides_for_shape
-from tinygrad.ops import Device
+from tinygrad.dtype import dtypes
+from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters
+from tinygrad.shape.view import strides_for_shape
+from tinygrad.multi import MultiLazyBuffer
 
-safe_dtypes = {"F16": dtypes.float16, "F32": dtypes.float32, "U8": dtypes.uint8, "I8": dtypes.int8, "I32": dtypes.int32, "I64": dtypes.int64}
+safe_dtypes = {"BOOL":dtypes.bool, "I8":dtypes.int8, "U8":dtypes.uint8, "I16":dtypes.int16, "U16":dtypes.uint16, "I32":dtypes.int, "U32":dtypes.uint,
+               "I64":dtypes.int64, "U64":dtypes.uint64, "F16":dtypes.float16, "BF16":dtypes.bfloat16, "F32":dtypes.float32, "F64":dtypes.float64}
 inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}
 
-def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
+def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]:
+  """
+  Loads a .safetensor file from disk, returning the data, metadata length, and metadata.
+  """
   t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
-  json_len = t[0:1].cast(dtypes.int64).numpy()[0]
-  metadata = json.loads(t[8:8+json_len].numpy().tobytes())
-  return {k:t[8+json_len+v['data_offsets'][0]:].cast(safe_dtypes[v['dtype']])[:prod(v['shape'])].reshape(v['shape']) for k,v in metadata.items() if k != "__metadata__"}
+  json_len = t[0:8].bitcast(dtypes.int64).item()
+  return t, json_len, json.loads(t[8:8+json_len].numpy().tobytes())
+
+def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
+  """
+  Loads a .safetensor file from disk, returning the state_dict.
 
+  ```python
+  state_dict = nn.state.safe_load("test.safetensor")
+  ```
+  """
+  t, json_len, metadata = safe_load_metadata(fn)
+  ret = {}
+  for k,v in metadata.items():
+    if k == "__metadata__": continue
+    dtype = safe_dtypes[v['dtype']]
+    sz = (v['data_offsets'][1]-v['data_offsets'][0])
+    ret[k] = t[8+json_len+v['data_offsets'][0]:8+json_len+v['data_offsets'][0]+sz].bitcast(dtype).reshape(v['shape'])
+  return ret
+
-def safe_save(tensors:Dict[str, Tensor], fn:str):
-  metadata, offset = {}, 0
+def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any]]=None):
+  """
+  Saves a state_dict to disk in a .safetensor file with optional metadata.
+
+  ```python
+  t = nn.Tensor([1, 2, 3])
+  nn.state.safe_save({'t':t}, "test.safetensor")
+  ```
+  """
+  headers, offset = {}, 0
+  if metadata: headers['__metadata__'] = metadata
   for k,v in tensors.items():
-    metadata[k] = {'dtype': inverse_safe_dtypes[v.dtype], 'shape': list(v.shape), 'data_offsets':[offset, offset+v.nbytes()]}
+    headers[k] = {'dtype': inverse_safe_dtypes[v.dtype], 'shape': list(v.shape), 'data_offsets':[offset, offset+v.nbytes()]}
     offset += v.nbytes()
-  j = json.dumps(metadata, separators=(',', ':'))
+  j = json.dumps(headers, separators=(',', ':'))
   j += "\x20"*((8-len(j)%8)%8)
   pathlib.Path(fn).unlink(missing_ok=True)
   t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}")
-  t[0:1].cast(dtypes.int64).assign([len(j)])
-  t[8:8+len(j)].assign(Tensor(list(j.encode('utf-8')), dtype=dtypes.uint8, device="cpu"))
+  t[0:8].bitcast(dtypes.int64).assign([len(j)])
+  t[8:8+len(j)].assign(list(j.encode('utf-8')))
   for k,v in safe_load(t).items(): v.assign(tensors[k])
 
 # state dict
 
 from collections import OrderedDict
 def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> Dict[str, Tensor]:
+  """
+  Returns a state_dict of the object, with optional prefix.
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  class Net:
+    def __init__(self):
+      self.l1 = nn.Linear(4, 5)
+      self.l2 = nn.Linear(5, 6)
+
+  net = Net()
+  print(nn.state.get_state_dict(net).keys())
+  ```
+  """
   if isinstance(obj, tensor_type): return {prefix.strip('.'):obj}
   if hasattr(obj, '_asdict'): return get_state_dict(obj._asdict(), prefix, tensor_type) # namedtuple
   if isinstance(obj, OrderedDict): return get_state_dict(dict(obj), prefix, tensor_type)
@@ -42,39 +85,71 @@ def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> Dict[str, Tensor]:
   elif isinstance(obj, dict):
     for k,v in obj.items(): state_dict.update(get_state_dict(v, f"{prefix}{str(k)}.", tensor_type))
   return state_dict
-def get_parameters(obj) -> List[Tensor]: return list(get_state_dict(obj).values())
+def get_parameters(obj) -> List[Tensor]:
+  """
+  ```python exec="true" source="above" session="tensor" result="python"
+  class Net:
+    def __init__(self):
+      self.l1 = nn.Linear(4, 5)
+      self.l2 = nn.Linear(5, 6)
+
+  net = Net()
+  print(len(nn.state.get_parameters(net)))
+  ```
+  """
+  return list(get_state_dict(obj).values())
+
+def load_state_dict(model, state_dict:Dict[str, Tensor], strict=True, verbose=True, consume=False) -> None:
+  """
+  Loads a state_dict into a model.
 
-def load_state_dict(model, state_dict, strict=True):
-  with Timing("loaded weights in ", lambda et_ns: f", {GlobalCounters.mem_used/1e9:.2f} GB loaded at {GlobalCounters.mem_used/et_ns:.2f} GB/s"):
+  ```python
+  class Net:
+    def __init__(self):
+      self.l1 = nn.Linear(4, 5)
+      self.l2 = nn.Linear(5, 6)
+
+  net = Net()
+  state_dict = nn.state.get_state_dict(net)
+  nn.state.load_state_dict(net, state_dict)
+  ```
+  """
+  start_mem_used = GlobalCounters.mem_used
+  with Timing("loaded weights in ", lambda et_ns: f", {(GlobalCounters.mem_used-start_mem_used)/1e9:.2f} GB loaded at {(GlobalCounters.mem_used-start_mem_used)/et_ns:.2f} GB/s"): # noqa: E501
     model_state_dict = get_state_dict(model)
-    if DEBUG >= 1 and len(state_dict) > len(model_state_dict): print("WARNING: unused weights in state_dict", sorted(list(state_dict.keys() - model_state_dict.keys())))
-    for k,v in (t := tqdm(model_state_dict.items(), disable=CI)):
+    if DEBUG >= 1 and len(state_dict) > len(model_state_dict):
+      print("WARNING: unused weights in state_dict", sorted(list(state_dict.keys() - model_state_dict.keys())))
+    for k,v in (t := tqdm(model_state_dict.items(), disable=CI or not verbose)):
       t.set_description(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, {k:50s}")
       if k not in state_dict and not strict:
         if DEBUG >= 1: print(f"WARNING: not loading {k}")
         continue
-      v.assign(state_dict[k].to(v.device)).realize()
+      if isinstance((mlb:=v.lazydata), MultiLazyBuffer):
+        if isinstance(state_dict[k].lazydata, MultiLazyBuffer): v.replace(state_dict[k]).realize()
+        else: v.replace(state_dict[k].shard(mlb.device, mlb.axis)).realize()
+      else: v.replace(state_dict[k].to(v.device)).realize()
+      if consume: del state_dict[k]
 
 # torch support!
 
-def torch_load(fn:str):
+def torch_load(fn:str) -> Dict[str, Tensor]:
+  """
+  Loads a torch .pth file from disk.
+
+  ```python
+  state_dict = nn.state.torch_load("test.pth")
+  ```
+  """
   t = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
 
-  offsets: Dict[str, int] = {}
-  lens: Dict[str, int] = {}
-  def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
+  offsets: Dict[Union[str, int], int] = {}
+  lens: Dict[Union[str, int], int] = {}
+  def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad=None, backward_hooks=None, metadata=None):
    #print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
     lens[storage[2]] = storage[4] * storage[1].itemsize
     if storage[2] not in offsets: return None
     byte_offset = offsets[storage[2]]+storage_offset*storage[1].itemsize
-    ret = t[byte_offset:byte_offset+prod(size)].cast(storage[1])
-    # convert bfloat16 -> float16 using LLVM for Llama 2
-    # upstream LLaMA also does this conversion:
-    # https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L95
-    # TODO: should this be done in the example instead? or maybe we don't need this anymore with better bfloat16 support
-    if storage[1] == dtypes.bfloat16:
-      ret = ret.bitcast(dtypes.uint16).to("CPU").cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).to(Device.DEFAULT).half()
-      #ret = ret.to("LLVM").half().to(Device.DEFAULT)
+    ret = t[byte_offset:byte_offset+prod(size)*storage[1].itemsize].bitcast(storage[1])
 
     # 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
     shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]
@@ -82,13 +157,20 @@ def torch_load(fn:str):
     if tuple(permute_indexes) != tuple(range(len(permute_indexes))):
       intermediate_shape = tuple([shape_strides[x][0] for x in argsort(permute_indexes)])
       assert tuple([shape_strides[i][1] for i in argsort(permute_indexes)]) == strides_for_shape(intermediate_shape), "nonpermutable strides"
-      if DEBUG >= 2: print(f"WARNING: this torch load is slow. CPU to permute {intermediate_shape} with {permute_indexes}")
+      if DEBUG >= 3: print(f"WARNING: this torch load is slow. CLANG to permute {intermediate_shape} with {permute_indexes}")
+      assert storage[1] != dtypes.bfloat16, "can't CLANG permute BF16"
      # TODO: find a nice way to support all shapetracker on disktensors
-      ret = ret.cpu().reshape(intermediate_shape).permute(permute_indexes)
+      # TODO: BUG: a ".realize()" is needed here for 'GPU=1 python3 test/models/test_efficientnet.py TestEfficientNet.test_car'
+      ret = ret.clang().reshape(intermediate_shape).permute(permute_indexes).realize()
 
     return ret.reshape(size)
 
-  intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16, "IntStorage": dtypes.int32, "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2}
+  class Parameter:
+    def __setstate__(self, state): self.tensor = state[0]
+
+  deserialized_objects: Dict[str, Any] = {}
+  intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16, "IntStorage": dtypes.int32,
+               "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2, "FloatTensor": None, "Parameter": Parameter}
   whitelist = {"torch", "collections", "numpy", "_codecs"} # NOTE: this is not for security, only speed
   class Dummy: pass
   class TorchPickle(pickle.Unpickler):
@@ -98,9 +180,9 @@ def torch_load(fn:str):
         if DEBUG >= 2: print(f"WARNING: returning Dummy for {module} {name}")
         return Dummy
       return intercept[name] if module_root == "torch" else super().find_class(module, name)
-    def persistent_load(self, pid): return pid
+    def persistent_load(self, pid): return deserialized_objects.get(pid, pid)
 
-  if tuple(t[0:2].numpy()) == (0x50, 0x4b):
+  if zipfile.is_zipfile(fn):
     myzip = zipfile.ZipFile(fn, 'r')
     base_name = myzip.namelist()[0].split('/', 1)[0]
     for n in myzip.namelist():
@@ -109,6 +191,21 @@ def torch_load(fn:str):
           offsets[n.split("/")[-1]] = myfile._orig_compress_start # type: ignore
     with myzip.open(f'{base_name}/data.pkl') as myfile:
       return TorchPickle(myfile).load()
+  elif tarfile.is_tarfile(fn):
+    with tarfile.open(fn, "r") as tar:
+      storages_offset = tar.getmember('storages').offset_data
+      f = unwrap(tar.extractfile('storages'))
+      for i in range(TorchPickle(f).load()): # num_storages
+        (key, _, storage_type), sz = TorchPickle(f).load(), struct.unpack('<q', f.read(8))[0]
+        offsets[key] = storages_offset + f.tell()
+        f.seek(sz*storage_type.itemsize, 1)
+      f = unwrap(tar.extractfile('tensors'))
+      for _ in range(TorchPickle(f).load()): # num_tensors
+        (key, storage_id, _), ndim, _ = TorchPickle(f).load(), struct.unpack('<i', f.read(4))[0], f.read(4)
+        size, stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim)), struct.unpack(f'<{ndim}q', f.read(8 * ndim))
+        storage_offset = struct.unpack('<q', f.read(8))[0]
+        deserialized_objects[str(key)] = _rebuild_tensor_v2((None, storage_type, storage_id, None, -1), storage_offset, size, stride)
+      return {k:v.tensor if isinstance(v, Parameter) else v for k,v in TorchPickle(unwrap(tar.extractfile('pickle'))).load().items()}
   else:
     with open(fn, "rb") as f:
      pkl = TorchPickle(f)
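For reference, the byte layout that the new safe_load_metadata/safe_save pair reads and writes is the standard safetensors framing: an 8-byte little-endian header length, a JSON header mapping tensor names to dtype/shape/data_offsets, then raw tensor bytes addressed relative to the end of the header. A minimal stdlib-only sketch of reading that header (the helper name is illustrative, not tinygrad API):

```python
# Minimal sketch: read a .safetensors header with the stdlib, mirroring what
# safe_load_metadata above does with a disk Tensor. File layout:
#   bytes 0..8    little-endian uint64 header length N
#   bytes 8..8+N  JSON header, e.g. {"w": {"dtype": "F32", "shape": [4, 5], "data_offsets": [0, 80]}}
#   bytes 8+N..   raw tensor bytes, indexed by data_offsets relative to 8+N
import json, struct

def read_safetensors_header(fn: str):  # hypothetical helper, not tinygrad API
    with open(fn, "rb") as f:
        (json_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(json_len))
    return json_len, header
```

This also explains the `"\x20"*((8-len(j)%8)%8)` line in safe_save above: the JSON header is padded with spaces so the data section starts on an 8-byte boundary.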
tinygrad/ops.py CHANGED
@@ -1,219 +1,136 @@
 from __future__ import annotations
-import time, importlib, inspect, functools, pathlib
+from typing import Union, Tuple, Any, List, Dict, Callable
+import functools, hashlib, math, operator, ctypes
 from enum import Enum, auto
-from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, cast
-from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, dedup, merge_dicts
-if TYPE_CHECKING: from tinygrad.lazy import LazyBuffer
+from dataclasses import dataclass
+from tinygrad.helpers import prod, dedup
+from tinygrad.dtype import dtypes, DType, ConstType
+from tinygrad.shape.symbolic import Variable, sint
+from tinygrad.shape.shapetracker import ShapeTracker
 
 # these are the llops your accelerator must implement, along with toCpu
 # the Enum class doesn't work with mypy, this is static. sorry it's ugly
 # NOTE: MOD, CMPLT don't have to be implemented on vectors, just scalars
-# NOTE: rdna3 only has RECIP and not DIV. DIV and POW are on the chopping block
-class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto(); RECIP = auto() # noqa: E702
-class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto() # noqa: E702
-class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
-class TernaryOps(Enum): MULACC = auto(); WHERE = auto() # noqa: E702
-class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto() # noqa: E702
-class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702
-
-Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, TernaryOps]
-OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOps], Type[LoadOps], Type[TernaryOps]]
-
+# NOTE: many GPUs don't have DIV, but UnaryOps.RECIP doesn't work for integer division
+class UnaryOps(Enum):
+  """A -> A (elementwise)"""
+  EXP2 = auto(); LOG2 = auto(); CAST = auto(); BITCAST = auto(); SIN = auto(); SQRT = auto(); NEG = auto() # noqa: E702
+class BinaryOps(Enum):
+  """A + A -> A (elementwise)"""
+  ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPEQ = auto(); XOR = auto() # noqa: E702
+class TernaryOps(Enum):
+  """A + A + A -> A (elementwise)"""
+  WHERE = auto(); MULACC = auto() # noqa: E702
+class ReduceOps(Enum):
+  """A -> B (reduce)"""
+  SUM = auto(); MAX = auto() # noqa: E702
+class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702
+class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); ASSIGN = auto(); VIEW = auto() # noqa: E702
+
+Op = Union[UnaryOps, BinaryOps, ReduceOps, LoadOps, TernaryOps, BufferOps]
+
+# do not preserve f(0) = 0
+UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, BinaryOps.CMPEQ, UnaryOps.LOG2, UnaryOps.EXP2}
+
+@dataclass(frozen=True)
+class MemBuffer:
+  idx: int
+  dtype: DType
+  st: ShapeTracker
+
+@dataclass(frozen=True)
+class ConstBuffer:
+  val: ConstType
+  dtype: DType
+  st: ShapeTracker
+
+@dataclass(frozen=True, eq=False)
 class LazyOp:
-  __slots__ = "op", "src", "arg", "buffers", "__weakref__"
   op: Op
-  src: Tuple[Union[LazyOp, LazyBuffer], ...]
-  arg: Any
-  buffers: Tuple[LazyBuffer, ...]
-  def __init__(self, op: Op, src: Tuple[Union[LazyOp, LazyBuffer], ...], arg: Any = None):
-    self.op, self.src, self.arg, self.buffers = op, src, arg, ()
-    try: # NOTE: the linearizer's key function maps the buffers to ints, and LOCAL_BUFFER is used. we don't care about buffers in these cases
-      for x in src: self.buffers += x.buffers
-    except AttributeError: self.buffers = ()
-
+  src: Tuple[LazyOp, ...] = ()
+  arg: Any = None
+  def cached_compare(self, x, context):
+    if id(self) == id(x): return True
+    if self.op != x.op or self.arg != x.arg or len(self.src) != len(x.src): return False
+    if (key := (id(self), id(x))) in context: return context[key]
+    ret = context[key] = all(a.cached_compare(b, context) for a,b in zip(self.src, x.src))
+    return ret
+  def __eq__(self, x): return self.cached_compare(x, context={})
   def __repr__(self): return f"LazyOp(op={self.op}, src={self.src}, arg={self.arg})"
-  def __eq__(self, __value: object) -> bool: return isinstance(__value, LazyOp) and self.op is __value.op and self.src == __value.src and self.arg == __value.arg
-  def __hash__(self) -> int: return hash((self.op, self.src, self.arg))
-  @property
-  def key(self): return (self.op, tuple(map(lambda x: getattr(x, "key", x), self.src)), getattr(self.arg, "key", self.arg))
-
-  # Any == Union[LazyBuffer, DeviceBuffer]
-  def map_buffers(self, real_srcs: Dict[Any, Any]) -> LazyOp: return LazyOp(self.op, tuple([y.map_buffers(real_srcs) for y in self.src]), self.arg)
-  def get_lazyops(self) -> List[LazyOp]: return [self] + [item for x in self.src for item in x.get_lazyops()]
-
-  def replace_with_movement_ops(self:LazyOp, ops:List[Tuple[MovementOps, Tuple[Any, ...]]]) -> 'LazyBuffer':
-    assert self.op in BinaryOps or self.op in UnaryOps or self.op in TernaryOps
-    srcs = [z.replace_with_movement_ops(ops) for z in self.src]
-    return srcs[0].e(self.op, *srcs[1:], arg=self.arg) # type: ignore
+  @functools.cached_property
+  def dtype(self) -> DType:
+    if self.op in BufferOps: return self.arg.dtype
+    if self.op in [UnaryOps.CAST, UnaryOps.BITCAST]: return self.arg
+    return dtypes.bool if self.op in {BinaryOps.CMPLT, BinaryOps.CMPEQ} else self.src[-1].dtype
 
-  @property
-  def st(self): raise NotImplementedError
-  @property
-  def children(self): raise NotImplementedError
-  @property
-  def shape(self): raise NotImplementedError
-  @property
-  def realized(self): raise NotImplementedError
-  @property
-  def optype(self): raise NotImplementedError
-  def realize(self): raise NotImplementedError
-
-  # movement ops
-  def reshape(self, _): raise NotImplementedError
-  def pad(self, _): raise NotImplementedError
-  def expand(self, _): raise NotImplementedError
-  def permute(self, _): raise NotImplementedError
-  def shrink(self, _): raise NotImplementedError
-  def stride(self, _): raise NotImplementedError
-
-# **************** Device ****************
-
-class _Device:
-  def __init__(self) -> None: self._buffers: List[str] = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
-  def canonicalize(self, device:Optional[str]) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") if device is not None else self.DEFAULT
-  @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
-  def __getitem__(self, x:str) -> Union[Interpreted, Compiled]:
-    x = x.split(":")[0].upper()
-    return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "buffer") and x in self._buffers][0]
   @functools.cached_property
-  def DEFAULT(self) -> str:
-    device_from_env: Optional[str] = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._buffers, None)
-    if device_from_env: return device_from_env
-    for device in ["METAL", "CUDA", "GPU"]:
-      try:
-        if self[device]: return device
-      except Exception: pass
-    return "CPU"
-Device = _Device()
-
-# **************** for Interpreted Buffers ****************
-
-class Interpreted:
-  def __init__(self, buffer, fxn_for_op: Dict[Op, Callable], from_lazybuffer=lambda x: x.realized, to_underlying=lambda x: x._buf, from_underlying=None):
-    self.buffer, self.fxn_for_op, self.from_lazybuffer, self.to_underlying = buffer, fxn_for_op, from_lazybuffer, to_underlying
-    self.from_underlying = buffer if from_underlying is None else from_underlying
-    self.synchronize = lambda: None
-    self.codegen = None
-
-  def exec_ast(self, ast:LazyOp, output=None, context=None, **kwargs):
-    if TernaryOps.MULACC in self.fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
-      ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
-    created_context = context is None
-    if context is None: context = dict()
-    if not created_context and ast in context: return context[ast]
-    srcs = [self.exec_ast(x, context=context, **kwargs) if isinstance(x, LazyOp) else self.from_lazybuffer(x) for x in ast.src]
-    if DEBUG >= 3: st = time.perf_counter()
-    ret = self.from_underlying(self.fxn_for_op[ast.op](*([self.to_underlying(x) for x in srcs] + ([ast.arg] if ast.arg is not None else []))))
-    if output is not None and ret.dtype != output.dtype and UnaryOps.CAST in self.fxn_for_op: ret = self.from_underlying(self.fxn_for_op[UnaryOps.CAST](self.to_underlying(ret), (output.dtype, False))) # Do manual casting of ret if it does not match the required output dtype.
-    if DEBUG >= 3: print(f"*** {'exec' if created_context else ' '} {GlobalCounters.mem_used/1e9:5.2f} GB {(time.perf_counter()-st)*1e3:7.2f} ms op: {ast.op:20s} out({ret.dtype.name}): {str(ret._buf.shape) if hasattr(ret._buf, 'shape') else str(len(ret._buf)):30s} in({len(srcs)}):", list(set(x._buf.shape if hasattr(x._buf, 'shape') else len(x._buf) for x in srcs)), ast.arg if ast.arg is not None else "")
-    if not created_context: context[ast] = ret
-    if output is not None and output.output_buffer is not None:
-      assert output.output_buffer.size == ret.size, output.output_buffer.dtype == ret.dtype
-      output.output_buffer._buf = ret._buf
-      return output.output_buffer
-    return ret
+  def key(self) -> bytes:
+    return hashlib.sha256(functools.reduce(lambda x,y: x+y, [s.key for s in self.src], str((self.op, self.arg)).encode())).digest()
+  @functools.cached_property
+  def hash(self): return hash((self.op, self.src, self.arg))
+  def __hash__(self): return self.hash
+  @functools.cached_property
+  def lazyops(self) -> List[LazyOp]: return dedup([self] + [item for x in self.src for item in x.lazyops])
+  def vars(self) -> List[Variable]:
+    extract_vars = [x.arg.st.vars() for x in self.lazyops if x.op in BufferOps]
+    const_vars = [x.arg.val.unbind()[0] for x in self.lazyops if x.op is BufferOps.CONST and isinstance(x.arg.val, Variable)]
+    return sorted(set.union(*extract_vars, set(const_vars)), key=lambda x: str(x.expr))
 
-# --teenygrad--
+# **************** independent FlopCounter ****************
 
+@dataclass
 class FlopCounter:
-  def __init__(self, tup:Tuple[Tuple[int, ...], DType, int]): self.shape, self.dtype, self.flops, self._buf = *tup, self
+  shape: Tuple[int, ...]
+  flops: sint
+  mem: Dict[int, int]
+  @property
+  def mem_estimate(self): return sum(self.mem.values())
   def consume_flops(self):
     self.flops, ret = 0, self.flops
     return ret
-shape_fxn_for_op: Dict[Op, Callable] = {
-  UnaryOps.CAST: lambda self,arg: (self.shape, arg[0], self.consume_flops()), # cast uses no flops
-  **{op:lambda self: (self.shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in UnaryOps if op != UnaryOps.CAST},
-  **{op:lambda self,y: (self.shape, max(self.dtype, y.dtype), self.consume_flops() + y.consume_flops() + prod(self.shape)) for op in BinaryOps},
-  **{op:lambda self,new_shape: (new_shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in ReduceOps},
-  TernaryOps.WHERE: lambda self,y,z: (self.shape, self.dtype, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape))}
-InterpretedFlopCounter = Interpreted(FlopCounter, shape_fxn_for_op, lambda x: FlopCounter((x.shape, x.dtype, 0)), lambda x: x)
-def get_lazyop_info(ast:LazyOp) -> FlopCounter: return InterpretedFlopCounter.exec_ast(ast)
-
-# **************** for Compiled Buffers ****************
-
-from tinygrad.runtime.lib import RawBuffer, RawConst, buf_is_kernel_arg
-from tinygrad.shape.symbolic import Variable, sym_infer
-
-class ASTRunner:
-  def __init__(self, name, prg, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None, runtime_args:Optional[dict]=None):
-    if DEBUG >= 4 and (runtime_args is None or 'binary' not in runtime_args or not runtime_args['binary']): print(prg)
-    self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {}
-
-  def build(self, runtime):
-    self.clprg = runtime(self.name, self.prg, **self.runtime_args)
-    return self
-
-  def exec(self, bufs, var_vals:Optional[Dict[Variable, int]]=None, force_wait=False, optimizing=False) -> Optional[float]:
-    rawbufs = dedup([x.realized for x in bufs if buf_is_kernel_arg(x)])
-    if GlobalCounters.cache is not None and not optimizing: GlobalCounters.cache.append((self, rawbufs, var_vals if var_vals is not None else {}))
-    return self(rawbufs, var_vals, force_wait=force_wait)
-
-  def __call__(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]:
-    if var_vals is None: var_vals = {}
-    global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else self.global_size
-    local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else self.local_size
-    if et := self.clprg((global_size + [1]*(3-len(global_size))) if global_size is not None else None,
-                        (local_size + [1]*(3-len(local_size))) if local_size is not None else None,
-                        *rawbufs, *var_vals.values(), wait=force_wait or DEBUG>=1): GlobalCounters.time_sum_s += et
-    op_estimate = sym_infer(self.op_estimate, var_vals)
-    if DEBUG >= 2:
-      print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(33-ansilen(self.display_name))) if self.display_name is not None else self.name:33s} arg {len(rawbufs):3d} sz {str(global_size):18s} {str(local_size):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
-            (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {self.mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
-    GlobalCounters.kernel_count += 1
-    GlobalCounters.global_ops += op_estimate
-    GlobalCounters.global_mem += self.mem_estimate
-    return et
-
-class Compiled:
-  def __init__(self, buffer: Type[RawBuffer], linearizer_opts, renderer, runtime, synchronize=lambda: None):
-    self.buffer, self.linearizer_opts, self.renderer, self.runtime, self.synchronize = buffer, linearizer_opts, renderer, runtime, synchronize
-    self.method_cache: Dict[Any, ASTRunner] = {}
-
-  def to_program(self, k):
-    k.linearize()
-    ret = self.renderer(k.function_name, k.uops)
-    src, global_size, local_size, binary = ret if len(ret) == 4 else ret + (False,)
-    return ASTRunner(k.function_name, src, global_size, local_size,
-                     op_estimate=k.info.flops, mem_estimate=k.mem_estimate,
-                     display_name=k.display_name, runtime_args={"binary": binary}).build(self.runtime)
-
-  def exec_ast(self, ast:LazyOp, output, **kwargs):
-    # all movementops do nothing in a Compiled buffer!
-    if ast.op in MovementOps and ast.src[0].__class__ is not LazyOp and ast.src[0].realized: return ast.src[0].realized
-
-    # check if we can reuse the output buffer
-    # if it's aliased, don't use it
-    # NOTE: this is pretty wrong actually, who knows where else this buffer is used?
-    output.realized = output.output_buffer
-    if output.realized:
-      if output.realized.__class__ is RawConst: output.realized = None # can't assign to RawConst
-      for a in ast.buffers:
-        if a.realized == output.realized and not a.st.contiguous:
-          output.realized = None
-          break
-
-    # we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
-    if not output.realized: output.realized = self.buffer(prod((s if isinstance(s, int) else s.max for s in output.shape)), output.dtype, **kwargs)
-    # update the output var_vals from src
-    output.st.var_vals = dict(sorted(merge_dicts([buf.st.var_vals for buf in ast.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key))
-
-    from tinygrad.codegen.linearizer import Linearizer
-    k = Linearizer(ast, output, self.linearizer_opts)
-
-    # compilation time
-    def get_program():
-      from tinygrad.codegen.search import kernel_optimize
-      if getenv("KOPT"): kernel_optimize(k, lambda: Linearizer(ast, output, self.linearizer_opts), self.to_program)
-      elif not getenv("NOOPT"): k.hand_coded_optimizations()
-      return self.to_program(k)
-
-    if hasattr(k, 'key') and getenv("ENABLE_METHOD_CACHE", 1):
-      if k.key not in self.method_cache: self.method_cache[k.key] = get_program()
-      prg = self.method_cache[k.key]
-    else:
-      prg = get_program()
-
-    if prg.name == getenv("PRINT_PRG", ''): print(prg.prg)
-
-    prg.exec(k.bufs, var_vals=output.st.var_vals)
-    return output.realized
+
+InterpretedFlopCounter: Dict[Op, Callable] = {
+  BufferOps.LOAD: lambda arg: FlopCounter(arg.st.shape, 0, {arg.idx: arg.dtype.itemsize * arg.st.real_size()}),
+  BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, 0, {}),
+  BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize * arg.st.real_size()}),
+  UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, self.consume_flops(), self.mem), # cast uses no flops
+  UnaryOps.BITCAST: lambda self,arg: FlopCounter(self.shape, self.consume_flops(), self.mem), # bitcast uses no flops
+  **{op:lambda self: FlopCounter(self.shape, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op not in {UnaryOps.CAST, UnaryOps.BITCAST}}, # noqa: E501
+  **{op:lambda self,y: FlopCounter(self.shape, self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, # noqa: E501
+  **{op:lambda self,axis: FlopCounter(tuple(1 if i in axis else s for i,s in enumerate(self.shape)), self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps}, # noqa: E501
+  TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})} # noqa: E501
+
+@functools.lru_cache(None)
+def get_lazyop_info(ast:LazyOp) -> FlopCounter:
+  @functools.lru_cache(None) # NOTE: this cache needs to be recreated for new ASTs
+  def run_ast(ast): return InterpretedFlopCounter[ast.op](*([run_ast(x) for x in ast.src]+([ast.arg] if ast.arg is not None else [])))
+  return run_ast(ast)
+
+# **************** ops in python ****************
+
+def hook_overflow(dv, fxn):
+  def wfxn(*args):
+    try: return fxn(*args)
+    except OverflowError: return dv
+  return wfxn
+
+python_alu = {
+  UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan,
+  UnaryOps.EXP2: hook_overflow(math.inf, lambda x: math.exp(x*math.log(2))),
+  UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.SIN: math.sin,
+  UnaryOps.NEG: lambda x: (not x) if isinstance(x, bool) else -x,
+  BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.XOR: operator.xor,
+  BinaryOps.MAX: max, BinaryOps.CMPEQ: operator.eq, BinaryOps.CMPLT: operator.lt,
+  BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0],
+  BinaryOps.DIV: lambda x,y: int(x/y) if isinstance(x, int) else (x/y if y != 0 else x*math.inf),
+  TernaryOps.WHERE: lambda x,y,z: y if x else z}
+
+truncate: Dict[DType, Callable] = {dtypes.bool: bool,
+  # TODO: float16 and bfloat16?
+  dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
+  dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
+  dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
+  dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value,
+  dtypes.int32: lambda x: ctypes.c_int32(x).value, dtypes.int64: lambda x: ctypes.c_int64(x).value,}
+
+def exec_alu(op:Op, dtype:DType, operands): return truncate.get(dtype, lambda x: x)(python_alu[op](*operands))
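
The python_alu and truncate tables added above give ops.py a self-contained Python ALU (backing the new tinygrad/runtime/ops_python.py listed earlier), with ctypes reproducing fixed-width integer wraparound. A small sketch of the semantics, assuming the definitions from this diff:

```python
# Illustrative only: what exec_alu does for a uint8 ADD, per the tables above.
# ctypes.c_uint8 keeps the low 8 bits, so overflow wraps the way a real
# fixed-width backend would.
import ctypes, operator

raw = operator.add(250, 10)          # python_alu[BinaryOps.ADD] -> 260
wrapped = ctypes.c_uint8(raw).value  # truncate[dtypes.uint8]    -> 4
assert wrapped == 4
# equivalently, with this module imported:
#   exec_alu(BinaryOps.ADD, dtypes.uint8, (250, 10)) == 4
```

Floats go through ctypes.c_float/c_double the same way, rounding a Python double to the storage precision; dtypes without an entry (e.g. float16, per the TODO) pass through untruncated.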