tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff compares the contents of two package versions publicly released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in that registry.
- tinygrad/__init__.py +1 -1
- tinygrad/apps/llm.py +206 -0
- tinygrad/codegen/__init__.py +116 -0
- tinygrad/codegen/devectorizer.py +315 -172
- tinygrad/codegen/expander.py +8 -16
- tinygrad/codegen/gpudims.py +89 -0
- tinygrad/codegen/linearize.py +205 -203
- tinygrad/codegen/lowerer.py +92 -139
- tinygrad/codegen/opt/__init__.py +38 -0
- tinygrad/codegen/opt/heuristic.py +125 -0
- tinygrad/codegen/opt/kernel.py +510 -0
- tinygrad/{engine → codegen/opt}/search.py +51 -35
- tinygrad/codegen/opt/swizzler.py +134 -0
- tinygrad/codegen/opt/tc.py +127 -0
- tinygrad/codegen/quantize.py +67 -0
- tinygrad/device.py +122 -132
- tinygrad/dtype.py +152 -35
- tinygrad/engine/jit.py +81 -54
- tinygrad/engine/memory.py +46 -27
- tinygrad/engine/realize.py +82 -41
- tinygrad/engine/schedule.py +70 -445
- tinygrad/frontend/__init__.py +0 -0
- tinygrad/frontend/onnx.py +1253 -0
- tinygrad/frontend/torch.py +5 -0
- tinygrad/gradient.py +19 -27
- tinygrad/helpers.py +95 -47
- tinygrad/nn/__init__.py +7 -8
- tinygrad/nn/optim.py +72 -41
- tinygrad/nn/state.py +37 -23
- tinygrad/renderer/__init__.py +40 -60
- tinygrad/renderer/cstyle.py +143 -128
- tinygrad/renderer/llvmir.py +113 -62
- tinygrad/renderer/ptx.py +50 -32
- tinygrad/renderer/wgsl.py +27 -23
- tinygrad/runtime/autogen/am/am.py +5861 -0
- tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
- tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
- tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
- tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
- tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
- tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
- tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
- tinygrad/runtime/autogen/comgr.py +35 -9
- tinygrad/runtime/autogen/comgr_3.py +906 -0
- tinygrad/runtime/autogen/cuda.py +2419 -494
- tinygrad/runtime/autogen/hsa.py +57 -16
- tinygrad/runtime/autogen/ib.py +7171 -0
- tinygrad/runtime/autogen/io_uring.py +917 -118
- tinygrad/runtime/autogen/kfd.py +748 -26
- tinygrad/runtime/autogen/libc.py +613 -218
- tinygrad/runtime/autogen/libusb.py +1643 -0
- tinygrad/runtime/autogen/nv/nv.py +8602 -0
- tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
- tinygrad/runtime/autogen/opencl.py +2 -4
- tinygrad/runtime/autogen/sqtt.py +1789 -0
- tinygrad/runtime/autogen/vfio.py +3 -3
- tinygrad/runtime/autogen/webgpu.py +273 -264
- tinygrad/runtime/graph/cuda.py +3 -3
- tinygrad/runtime/graph/hcq.py +68 -29
- tinygrad/runtime/graph/metal.py +29 -13
- tinygrad/runtime/graph/remote.py +114 -0
- tinygrad/runtime/ops_amd.py +537 -320
- tinygrad/runtime/ops_cpu.py +108 -7
- tinygrad/runtime/ops_cuda.py +12 -14
- tinygrad/runtime/ops_disk.py +13 -10
- tinygrad/runtime/ops_dsp.py +47 -40
- tinygrad/runtime/ops_gpu.py +13 -11
- tinygrad/runtime/ops_hip.py +6 -9
- tinygrad/runtime/ops_llvm.py +35 -15
- tinygrad/runtime/ops_metal.py +29 -19
- tinygrad/runtime/ops_npy.py +5 -3
- tinygrad/runtime/ops_null.py +28 -0
- tinygrad/runtime/ops_nv.py +306 -234
- tinygrad/runtime/ops_python.py +62 -52
- tinygrad/runtime/ops_qcom.py +28 -39
- tinygrad/runtime/ops_remote.py +482 -0
- tinygrad/runtime/ops_webgpu.py +28 -28
- tinygrad/runtime/support/am/amdev.py +114 -249
- tinygrad/runtime/support/am/ip.py +211 -172
- tinygrad/runtime/support/amd.py +138 -0
- tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
- tinygrad/runtime/support/compiler_cuda.py +8 -11
- tinygrad/runtime/support/elf.py +2 -1
- tinygrad/runtime/support/hcq.py +184 -97
- tinygrad/runtime/support/ib.py +172 -0
- tinygrad/runtime/support/llvm.py +3 -4
- tinygrad/runtime/support/memory.py +251 -0
- tinygrad/runtime/support/nv/__init__.py +0 -0
- tinygrad/runtime/support/nv/ip.py +581 -0
- tinygrad/runtime/support/nv/nvdev.py +183 -0
- tinygrad/runtime/support/system.py +170 -0
- tinygrad/runtime/support/usb.py +268 -0
- tinygrad/runtime/support/webgpu.py +18 -0
- tinygrad/schedule/__init__.py +0 -0
- tinygrad/schedule/grouper.py +119 -0
- tinygrad/schedule/kernelize.py +368 -0
- tinygrad/schedule/multi.py +231 -0
- tinygrad/shape/shapetracker.py +40 -46
- tinygrad/shape/view.py +88 -52
- tinygrad/tensor.py +968 -542
- tinygrad/uop/__init__.py +117 -0
- tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
- tinygrad/uop/mathtraits.py +169 -0
- tinygrad/uop/ops.py +1021 -0
- tinygrad/uop/spec.py +228 -0
- tinygrad/{codegen → uop}/symbolic.py +239 -216
- tinygrad/uop/upat.py +163 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
- tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
- tinygrad/viz/index.html +203 -403
- tinygrad/viz/js/index.js +718 -0
- tinygrad/viz/js/worker.js +29 -0
- tinygrad/viz/serve.py +224 -102
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
- tinygrad-0.11.0.dist-info/RECORD +141 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/kernel.py +0 -693
- tinygrad/engine/multi.py +0 -161
- tinygrad/ops.py +0 -1003
- tinygrad/runtime/ops_cloud.py +0 -220
- tinygrad/runtime/support/allocator.py +0 -94
- tinygrad/spec.py +0 -155
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
- tinygrad/viz/perfetto.html +0 -178
- tinygrad-0.10.2.dist-info/RECORD +0 -99
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
tinygrad/__init__.py
CHANGED
@@ -4,7 +4,7 @@ if int(os.getenv("TYPED", "0")):
   install_import_hook(__name__)
 from tinygrad.tensor import Tensor # noqa: F401
 from tinygrad.engine.jit import TinyJit # noqa: F401
-from tinygrad.ops import UOp
+from tinygrad.uop.ops import UOp
 Variable = UOp.variable
 from tinygrad.dtype import dtypes # noqa: F401
 from tinygrad.helpers import GlobalCounters, fetch, Context, getenv # noqa: F401
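The package-root alias is unchanged by this hunk; only the module backing UOp moved from tinygrad.ops to tinygrad.uop.ops. A minimal sketch of the new import path, using only the names shown above:

from tinygrad import Variable           # still defined as UOp.variable in __init__.py
from tinygrad.uop.ops import UOp        # new location of the class formerly at tinygrad.ops.UOp

v = Variable("i", 0, 15)                # a bounded symbolic variable, as used elsewhere in 0.11.0
assert isinstance(v, UOp)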
tinygrad/apps/llm.py
ADDED
@@ -0,0 +1,206 @@
from __future__ import annotations
import sys, argparse, typing, re, itertools, unicodedata
from tinygrad import Tensor, nn, UOp, TinyJit, getenv, helpers

def gpt2_decode_vocab(voc: dict[str, int]): # https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/encoder.py#L9
  c2b = { chr(cp): cp for cp in itertools.chain(range(ord("!"), ord("~")+1), range(ord("¡"), ord("¬")+1), range(ord("®"), ord("ÿ")+1)) }
  c2b.update({ chr(256+off): cp for off, cp in enumerate(cp for cp in range(256) if chr(cp) not in c2b) })
  return { bytes(c2b[c] for c in tok): tid for tok, tid in voc.items() }

def get_llama_re():
  def ucat_range(pre: str): return "".join(re.escape(chr(cp)) for cp in range(sys.maxunicode + 1) if unicodedata.category(chr(cp)).startswith(pre))
  r_ws, r_p_N, r_p_L = r"\t\n\x0b\x0c\r\x85" + ucat_range("Z"), ucat_range("N"), ucat_range("L")
  # https://github.com/ggml-org/llama.cpp/blob/94933c8c2eeaa9a7983e3f6c08af76bd86724094/src/llama-vocab.cpp#L286
  return "(?i:'s|'t|'re|'ve|'m|'ll|'d)|" + \
    f"[^\\r\\n{r_p_N}{r_p_L}]?[{r_p_L}]+|[{r_p_N}]{{1,3}}| ?[^{r_ws}{r_p_N}{r_p_L}]+[\\r\\n]*|[{r_ws}]*[\\r\\n]+|[{r_ws}]+(?![^{r_ws}])|[{r_ws}]+"

class SimpleTokenizer:
  def __init__(self, pat: str, normal_tokens: dict[bytes, int], special_tokens: dict[str, int]):
    self._normal_tokens, self._special_tokens, self._pat = normal_tokens, special_tokens, re.compile(pat)
    self._tok2str = { tid: tok.encode() for tok, tid in special_tokens.items() } | { tid: tok for tok, tid in normal_tokens.items() }
    self._special_re = re.compile("|".join(re.escape(tok) for tok in self._special_tokens.keys()) if special_tokens else r"(?!)")

  @staticmethod
  def from_gguf_kv(kv: dict):
    # https://github.com/ggml-org/llama.cpp/blob/94933c8c2eeaa9a7983e3f6c08af76bd86724094/src/llama-vocab.cpp#L1818-L1820
    if kv["tokenizer.ggml.pre"] not in ("llama3","llama-v3","llama-bpe"): raise ValueError(f"Invalid tokenizer preset '{kv['tokenizer.ggml.pre']}'")
    vocab: typing.Iterable[tuple[str, int]] = ((tok, idx) for idx, tok in enumerate(kv["tokenizer.ggml.tokens"]))
    normal_tokens, special_tokens = helpers.partition(vocab, lambda e: kv["tokenizer.ggml.token_type"][e[1]] == 1)
    return SimpleTokenizer(get_llama_re(), gpt2_decode_vocab(dict(normal_tokens)), dict(special_tokens))

  def encode(self, text: str):
    tokens: list[int] = []
    pos = 0
    for match in self._special_re.finditer(text):
      tokens.extend(self._encode_sentence(text[pos:match.start(0)]) + [self._special_tokens[text[match.start(0):match.end(0)]]])
      pos = match.end(0)
    return tokens + self._encode_sentence(text[pos:])

  def decode(self, ids: list[int]) -> str: return b''.join(self._tok2str[tid] for tid in ids).decode()
  def role(self, role:str): return self.encode("<|start_header_id|>" + role + "<|end_header_id|>\n\n")

  def _encode_sentence(self, chunk: str): return [ tok for word in self._pat.findall(chunk) for tok in self._encode_word(word.encode()) ]
  def _encode_word(self, word: bytes):
    if (early_token:=self._normal_tokens.get(word)) is not None: return [early_token]
    parts = [word[i:i+1] for i in range(len(word))]
    while True:
      min_tid, min_idx = 2**32, -1
      for idx, (p1, p2) in enumerate(zip(parts[:-1], parts[1:])):
        tid = self._normal_tokens.get(p1 + p2, min_tid)
        if tid < min_tid: min_tid, min_idx = tid, idx
      if min_idx == -1: break
      parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx+1]] + parts[min_idx+2:]
    try: return [ self._normal_tokens[p] for p in parts ]
    except KeyError: raise RuntimeError("token not found")

def apply_rope(x:Tensor, start_pos:int|UOp, base:int=10000):
  B, H, T, Hd = x.shape
  # NOTE: this is usually in a RoPE cache, but tinygrad JIT should prune it outside the kernel
  # TODO: make it do that
  freq = base ** (-Tensor.arange(0, 1, 2/Hd, dtype='float32'))
  angles = Tensor.arange(start_pos, start_pos+T, dtype='float32')[None, None, :, None] * freq
  cos, sin = angles.cos(), angles.sin()
  x = x.reshape(B, H, T, Hd // 2, 2) # split into pairs
  y1 = x[..., 0] * cos - x[..., 1] * sin
  y2 = x[..., 0] * sin + x[..., 1] * cos
  return Tensor.stack(y1, y2, dim=-1).reshape(B, H, T, Hd)

class TransformerBlock:
  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int=0):
    self.n_heads = n_heads
    self.n_kv_heads = n_kv_heads
    self.head_dim = dim // n_heads
    self.max_context = max_context

    # --- attention projections (all linear, bias-free) ------------------
    kv_proj_out = self.head_dim * n_kv_heads # Llama-3 uses the same dim for K/V
    self.attn_q = nn.Linear(dim, dim, bias=False)
    self.attn_k = nn.Linear(dim, kv_proj_out, bias=False)
    self.attn_v = nn.Linear(dim, kv_proj_out, bias=False)
    self.attn_output = nn.Linear(dim, dim, bias=False)

    # --- RMSNorms --------------------------------------------------------
    self.attn_norm = nn.RMSNorm(dim, norm_eps)
    self.ffn_norm = nn.RMSNorm(dim, norm_eps)

    # --- feed-forward ----------------------------------------------------
    self.ffn_gate = nn.Linear(dim, hidden_dim, bias=False)
    self.ffn_up = nn.Linear(dim, hidden_dim, bias=False)
    self.ffn_down = nn.Linear(hidden_dim, dim, bias=False)

  def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
    x_norm = self.attn_norm(x) # (B,T,D)
    q, k, v = self.attn_q(x_norm), self.attn_k(x_norm), self.attn_v(x_norm)

    B, T, _ = x.shape
    q = q.reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2) # (B,H,T,Hd)
    k = k.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
    v = v.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)

    q = apply_rope(q, start_pos)
    k = apply_rope(k, start_pos)

    # TODO: remove these kv cache realizes
    if not hasattr(self, "cache_kv"):
      self.cache_kv = Tensor.zeros(2, B, self.n_kv_heads, self.max_context, self.head_dim, dtype=k.dtype, device=k.device).contiguous().realize()
    self.cache_kv[:, :, :, start_pos:start_pos+T, :].assign(Tensor.stack(k, v)).realize() # type: ignore
    k = self.cache_kv[0, :, :, 0:start_pos+T, :]
    v = self.cache_kv[1, :, :, 0:start_pos+T, :]

    # NOTE: this mask is causal_lower_right, not the causal_upper_left generated by is_casual = True
    mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1) if T > 1 else None
    attn = q.scaled_dot_product_attention(k, v, attn_mask=mask, enable_gqa=True) # (B,H,T,Hd)
    attn = attn.transpose(1, 2).reshape(B, T, -1) # back to (B,T,D)
    attn = self.attn_output(attn)
    return x + attn

  def _feed_forward(self, h: Tensor) -> Tensor:
    h_norm = self.ffn_norm(h)
    gated = self.ffn_gate(h_norm).silu() * self.ffn_up(h_norm)
    return h + self.ffn_down(gated)

  def __call__(self, x: Tensor, start_pos: int|UOp):
    return self._feed_forward(self._attention(x, start_pos))

class Transformer:
  def __init__(self, *, num_blocks, dim, hidden_dim, n_heads, n_kv_heads, norm_eps, vocab_size, max_context):
    self.blk = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context) for _ in range(num_blocks)]
    self.token_embd = nn.Embedding(vocab_size, dim)
    self.output_norm = nn.RMSNorm(dim, norm_eps)
    self.output = nn.Linear(dim, vocab_size, bias=False)
    self.max_context = max_context
    # JIT is used if T=1 and start_pos is a UOp. TODO: make this not needed by including T in the JIT and making start_pos always a UOp
    self.forward_jit = TinyJit(self.forward)

  def forward(self, tokens:Tensor, start_pos:int|UOp) -> Tensor:
    x = self.token_embd(tokens) # (B, T, D)
    for block in self.blk: x = block(x, start_pos)
    # TODO: add temperature
    return self.output(self.output_norm(x))[:, -1, :].softmax(-1).argmax(-1, keepdim=True)

  def __call__(self, tokens:Tensor, start_pos:int|UOp=0) -> Tensor:
    return (self.forward_jit if getenv("JIT", 1) and tokens.shape[1] == 1 and isinstance(start_pos, UOp) else self.forward)(tokens, start_pos)

  @staticmethod
  def from_gguf(gguf:Tensor, max_context:int|None=None) -> tuple[Transformer, dict]:
    # TODO: remove the need for copy to default device
    kv, state_dict = nn.state.gguf_load(gguf.to(None))

    # all state items should be float16, not float32
    state_dict = {k:v.cast('float16') for k,v in state_dict.items()}

    # some models like Llama 3.2 don't have an output.weight, they just tie to the token_embd.weight
    if 'output.weight' not in state_dict: state_dict['output.weight'] = state_dict['token_embd.weight']

    arch = kv['general.architecture']
    max_context = min(max_context, kv[f'{arch}.context_length']) if max_context is not None else kv[f'{arch}.context_length']
    model = Transformer(num_blocks=kv[f'{arch}.block_count'], dim=kv[f'{arch}.embedding_length'], hidden_dim=kv[f'{arch}.feed_forward_length'],
                        n_heads=kv[f'{arch}.attention.head_count'], n_kv_heads=kv[f'{arch}.attention.head_count_kv'],
                        norm_eps=kv[f'{arch}.attention.layer_norm_rms_epsilon'], vocab_size=len(kv['tokenizer.ggml.tokens']), max_context=max_context)
    nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False) # NOTE: rope_freqs.weight (32,) is unused
    return model, kv

  def generate(self, tokens:list[int], start_pos=0):
    v_start_pos = UOp.variable("start_pos", 1, self.max_context-1)
    start_pos = 0
    t = Tensor([tokens[start_pos:]], dtype="int32")
    self.forward_jit.reset() # TODO: why is this required? root cause the issue and make it not be needed
    while len(tokens) < self.max_context:
      t = self(t, v_start_pos.bind(start_pos) if getenv("SYM", 1) and start_pos != 0 and t.shape[-1] == 1 else start_pos)
      next_id = int(t.item())
      tokens.append(next_id)
      start_pos = len(tokens) - 1
      yield next_id

models = {
  "1B": "https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q6_K.gguf",
  "3B": "https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/Llama-3.2-3B-Instruct-Q6_K.gguf",
  "3B_f16": "https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/Llama-3.2-3B-Instruct-f16.gguf",
  "8B": "https://huggingface.co/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/resolve/main/Meta-Llama-3.1-8B-Instruct-Q8_0.gguf",
}

if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument("--size", choices=list(models.keys()), default=list(models.keys())[0], help="Model size")
  parser.add_argument("--max_context", type=int, default=4096, help="Max Context Length")
  args = parser.parse_args()

  # load the model
  model, kv = Transformer.from_gguf(Tensor.from_url(models[args.size]), args.max_context)

  # extract some metadata
  tok = SimpleTokenizer.from_gguf_kv(kv)
  bos_id: int = kv['tokenizer.ggml.bos_token_id']
  eos_id: int = kv['tokenizer.ggml.eos_token_id']

  ids: list[int] = [bos_id]
  while 1:
    start_pos = len(ids) - 1
    try:
      ids += tok.role("user") + tok.encode(input('>>> ')) + [eos_id] + tok.role("assistant")
    except EOFError:
      break
    for next_id in model.generate(ids, start_pos):
      sys.stdout.write(tok.decode([next_id]) if next_id != eos_id else "\n\n")
      sys.stdout.flush()
      if next_id == eos_id: break
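Outside the interactive __main__ loop, the same pieces compose into a small programmatic API. A minimal sketch, assuming the new file is importable as tinygrad.apps.llm and using only names defined above (models, Transformer.from_gguf, SimpleTokenizer.from_gguf_kv, Transformer.generate); the prompt string is illustrative:

from tinygrad import Tensor
from tinygrad.apps.llm import Transformer, SimpleTokenizer, models

# fetch the 1B GGUF and build the tokenizer from its metadata
model, kv = Transformer.from_gguf(Tensor.from_url(models["1B"]), max_context=1024)
tok = SimpleTokenizer.from_gguf_kv(kv)
bos_id, eos_id = kv['tokenizer.ggml.bos_token_id'], kv['tokenizer.ggml.eos_token_id']

# one user turn, then greedy decoding (forward takes argmax) until the end-of-sequence token
ids = [bos_id] + tok.role("user") + tok.encode("What is tinygrad?") + [eos_id] + tok.role("assistant")
for next_id in model.generate(ids, start_pos=len(ids) - 1):
  if next_id == eos_id: break
  print(tok.decode([next_id]), end="", flush=True)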
tinygrad/codegen/__init__.py
CHANGED
@@ -0,0 +1,116 @@
from typing import Any, Callable
import functools
from dataclasses import dataclass
from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp
from tinygrad.uop.spec import type_verify
from tinygrad.renderer import Renderer

# import all pattern matchers here
from tinygrad.codegen.lowerer import pm_lowerer, get_index
from tinygrad.codegen.quantize import pm_quant
from tinygrad.codegen.gpudims import pm_add_gpudims
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing
from tinygrad.uop.decompositions import get_late_rewrite_patterns
from tinygrad.codegen.expander import migrate_indexing, expander
from tinygrad.codegen.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
  ReduceContext, correct_load_store, pm_render
from tinygrad.codegen.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
from tinygrad.codegen.opt import pm_optimize
from tinygrad.codegen.opt.swizzler import view_left, view_right, fix_kernel_ops

@dataclass
class RewriteStep:
  pm: PatternMatcher
  ctx: Callable[[UOp], Any]|None = None
  name: str|None = None
  bottom_up: bool = False
  def __call__(self, sink:UOp):
    return graph_rewrite(sink, self.pm, ctx=self.ctx(sink) if self.ctx is not None else None, name=self.name, bottom_up=self.bottom_up)

def apply_rewrites(sink:UOp, rewrites:list[RewriteStep]): return functools.reduce(lambda x,f: f(x), rewrites, sink)

rewrites_for_views = [
  RewriteStep(view_left, name="Main View Left"),
  RewriteStep(view_right, name="Main View Right"),
  RewriteStep(view_left+fix_kernel_ops, bottom_up=True, name="Finalize Kernel"),
]

rewrites_for_linearizer = [
  RewriteStep(block_create, ctx=BlockContext.from_sink, name="Linearizer: Create Blocks", bottom_up=True),
  RewriteStep(pm_blockend_merge, name="Linearizer: Merge Blockends"),
  RewriteStep(block_merge, name="Linearizer: Merge Blocks"),
  RewriteStep(pm_finalize, name="Linearizer: Finalize")]

def get_rewrites_for_renderer(opts:Renderer, linearizer:bool=True) -> list[RewriteStep]:
  # cache with the values of the context vars
  return _get_rewrites_for_renderer(opts, linearizer, QUANTIZE.value, DEVECTORIZE.value, TRANSCENDENTAL.value)

@functools.cache
def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL) -> list[RewriteStep]:
  # ** lowerer (rewrite_shapetracker_with_index) **
  ret: list[RewriteStep] = []

  # view pushing
  ret.extend(rewrites_for_views)

  # this is kernel.py
  ret.append(RewriteStep(pm_optimize, ctx=lambda _: opts, name="optimize ast"))

  if _QUANTIZE and opts.device in {"CPU", "DSP"}: ret.append(RewriteStep(pm_quant, name="quantize"))
  ret.append(RewriteStep(pm_lowerer, get_index, name="lowerer", bottom_up=True))

  # ** expander (expand_rewrite) **
  ret.append(RewriteStep(sym+migrate_indexing, name="initial symbolic"))

  # expand
  ret.append(RewriteStep(sym+expander, name="expander"))

  # ** devectorizer (full_graph_rewrite) **
  # remove reduce
  ret.append(RewriteStep(pm_reduce+gep_pushing, lambda _: ReduceContext(), name="remove_reduce"))

  # add gpu dims (late)
  ret.append(RewriteStep(pm_add_gpudims, lambda _: opts, name="add gpudims"))

  # devectorize (TODO: does this need opts?)
  if _DEVECTORIZE >= 2: pm_devectorize = sym+load_store_folding+load_store_indexing
  elif _DEVECTORIZE: pm_devectorize = sym+devectorize+load_store_folding+correct_load_store+load_store_indexing
  else: pm_devectorize = sym+load_store_folding+correct_load_store+load_store_indexing
  ret.append(RewriteStep(pm_devectorize, lambda _: opts, name="devectorize"))

  supported_ops = tuple(opts.code_for_op.keys())
  extra_matcher = opts.extra_matcher if opts.extra_matcher is not None else PatternMatcher([])

  # optional pre matcher
  if opts.pre_matcher is not None: ret.append(RewriteStep(opts.pre_matcher, name="pre_matcher"))

  # decompositions
  pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, _TRANSCENDENTAL>=2)
  ret.append(RewriteStep(pm_decomp, lambda _: opts.device, name="decompositions"))

  # final rules for the renderer (without sym)
  pm_final_rewrite = pm_decomp+pm_render+extra_matcher
  ret.append(RewriteStep(pm_final_rewrite, lambda _: opts.device, name="final rewrite"))

  # return the list (with optional linearizer)
  return ret + (rewrites_for_linearizer if linearizer else [])

def full_rewrite_to_sink(sink:UOp, opts:Renderer|None=None, linearizer:bool=False) -> UOp:
  return apply_rewrites(sink, get_rewrites_for_renderer(opts if opts is not None else Renderer(), linearizer))

def full_rewrite(sink:UOp, opts:Renderer|None=None) -> list[UOp]:
  """
  Function to transform the Kernel UOp graph into a linearized program.

  Args:
    sink: The Ops.SINK rooting the Kernel graph.
    opts: The Renderer (can change how things are processed, fix this).

  Returns:
    Linear program in UOps.
  """

  lst = list(full_rewrite_to_sink(sink, opts, linearizer=True).arg.lst)
  if __debug__: type_verify(lst)
  return lst