tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
tinygrad/__init__.py CHANGED
@@ -4,7 +4,7 @@ if int(os.getenv("TYPED", "0")):
  install_import_hook(__name__)
  from tinygrad.tensor import Tensor # noqa: F401
  from tinygrad.engine.jit import TinyJit # noqa: F401
- from tinygrad.ops import UOp
+ from tinygrad.uop.ops import UOp
  Variable = UOp.variable
  from tinygrad.dtype import dtypes # noqa: F401
  from tinygrad.helpers import GlobalCounters, fetch, Context, getenv # noqa: F401
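
The only change to the top-level package is the import path of UOp: tinygrad/ops.py is removed in this release (item 122 in the list above) and its contents now live under tinygrad/uop/. A minimal migration sketch for downstream code, grounded only in the diff above:

# 0.10.2: UOp was importable from the now-removed tinygrad/ops.py
#   from tinygrad.ops import UOp
# 0.11.0: the same class lives in tinygrad/uop/ops.py
from tinygrad.uop.ops import UOp
# the top-level re-export is unchanged, so version-agnostic code can keep using:
from tinygrad import Tensor, UOp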
tinygrad/apps/llm.py ADDED
@@ -0,0 +1,206 @@
+ from __future__ import annotations
+ import sys, argparse, typing, re, itertools, unicodedata
+ from tinygrad import Tensor, nn, UOp, TinyJit, getenv, helpers
+
+ def gpt2_decode_vocab(voc: dict[str, int]): # https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/encoder.py#L9
+   c2b = { chr(cp): cp for cp in itertools.chain(range(ord("!"), ord("~")+1), range(ord("¡"), ord("¬")+1), range(ord("®"), ord("ÿ")+1)) }
+   c2b.update({ chr(256+off): cp for off, cp in enumerate(cp for cp in range(256) if chr(cp) not in c2b) })
+   return { bytes(c2b[c] for c in tok): tid for tok, tid in voc.items() }
+
+ def get_llama_re():
+   def ucat_range(pre: str): return "".join(re.escape(chr(cp)) for cp in range(sys.maxunicode + 1) if unicodedata.category(chr(cp)).startswith(pre))
+   r_ws, r_p_N, r_p_L = r"\t\n\x0b\x0c\r\x85" + ucat_range("Z"), ucat_range("N"), ucat_range("L")
+   # https://github.com/ggml-org/llama.cpp/blob/94933c8c2eeaa9a7983e3f6c08af76bd86724094/src/llama-vocab.cpp#L286
+   return "(?i:'s|'t|'re|'ve|'m|'ll|'d)|" + \
+     f"[^\\r\\n{r_p_N}{r_p_L}]?[{r_p_L}]+|[{r_p_N}]{{1,3}}| ?[^{r_ws}{r_p_N}{r_p_L}]+[\\r\\n]*|[{r_ws}]*[\\r\\n]+|[{r_ws}]+(?![^{r_ws}])|[{r_ws}]+"
+
+ class SimpleTokenizer:
+   def __init__(self, pat: str, normal_tokens: dict[bytes, int], special_tokens: dict[str, int]):
+     self._normal_tokens, self._special_tokens, self._pat = normal_tokens, special_tokens, re.compile(pat)
+     self._tok2str = { tid: tok.encode() for tok, tid in special_tokens.items() } | { tid: tok for tok, tid in normal_tokens.items() }
+     self._special_re = re.compile("|".join(re.escape(tok) for tok in self._special_tokens.keys()) if special_tokens else r"(?!)")
+
+   @staticmethod
+   def from_gguf_kv(kv: dict):
+     # https://github.com/ggml-org/llama.cpp/blob/94933c8c2eeaa9a7983e3f6c08af76bd86724094/src/llama-vocab.cpp#L1818-L1820
+     if kv["tokenizer.ggml.pre"] not in ("llama3","llama-v3","llama-bpe"): raise ValueError(f"Invalid tokenizer preset '{kv['tokenizer.ggml.pre']}'")
+     vocab: typing.Iterable[tuple[str, int]] = ((tok, idx) for idx, tok in enumerate(kv["tokenizer.ggml.tokens"]))
+     normal_tokens, special_tokens = helpers.partition(vocab, lambda e: kv["tokenizer.ggml.token_type"][e[1]] == 1)
+     return SimpleTokenizer(get_llama_re(), gpt2_decode_vocab(dict(normal_tokens)), dict(special_tokens))
+
+   def encode(self, text: str):
+     tokens: list[int] = []
+     pos = 0
+     for match in self._special_re.finditer(text):
+       tokens.extend(self._encode_sentence(text[pos:match.start(0)]) + [self._special_tokens[text[match.start(0):match.end(0)]]])
+       pos = match.end(0)
+     return tokens + self._encode_sentence(text[pos:])
+
+   def decode(self, ids: list[int]) -> str: return b''.join(self._tok2str[tid] for tid in ids).decode()
+   def role(self, role:str): return self.encode("<|start_header_id|>" + role + "<|end_header_id|>\n\n")
+
+   def _encode_sentence(self, chunk: str): return [ tok for word in self._pat.findall(chunk) for tok in self._encode_word(word.encode()) ]
+   def _encode_word(self, word: bytes):
+     if (early_token:=self._normal_tokens.get(word)) is not None: return [early_token]
+     parts = [word[i:i+1] for i in range(len(word))]
+     while True:
+       min_tid, min_idx = 2**32, -1
+       for idx, (p1, p2) in enumerate(zip(parts[:-1], parts[1:])):
+         tid = self._normal_tokens.get(p1 + p2, min_tid)
+         if tid < min_tid: min_tid, min_idx = tid, idx
+       if min_idx == -1: break
+       parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx+1]] + parts[min_idx+2:]
+     try: return [ self._normal_tokens[p] for p in parts ]
+     except KeyError: raise RuntimeError("token not found")
+
+ def apply_rope(x:Tensor, start_pos:int|UOp, base:int=10000):
+   B, H, T, Hd = x.shape
+   # NOTE: this is usually in a RoPE cache, but tinygrad JIT should prune it outside the kernel
+   # TODO: make it do that
+   freq = base ** (-Tensor.arange(0, 1, 2/Hd, dtype='float32'))
+   angles = Tensor.arange(start_pos, start_pos+T, dtype='float32')[None, None, :, None] * freq
+   cos, sin = angles.cos(), angles.sin()
+   x = x.reshape(B, H, T, Hd // 2, 2) # split into pairs
+   y1 = x[..., 0] * cos - x[..., 1] * sin
+   y2 = x[..., 0] * sin + x[..., 1] * cos
+   return Tensor.stack(y1, y2, dim=-1).reshape(B, H, T, Hd)
+
+ class TransformerBlock:
+   def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int=0):
+     self.n_heads = n_heads
+     self.n_kv_heads = n_kv_heads
+     self.head_dim = dim // n_heads
+     self.max_context = max_context
+
+     # --- attention projections (all linear, bias-free) ------------------
+     kv_proj_out = self.head_dim * n_kv_heads # Llama-3 uses the same dim for K/V
+     self.attn_q = nn.Linear(dim, dim, bias=False)
+     self.attn_k = nn.Linear(dim, kv_proj_out, bias=False)
+     self.attn_v = nn.Linear(dim, kv_proj_out, bias=False)
+     self.attn_output = nn.Linear(dim, dim, bias=False)
+
+     # --- RMSNorms --------------------------------------------------------
+     self.attn_norm = nn.RMSNorm(dim, norm_eps)
+     self.ffn_norm = nn.RMSNorm(dim, norm_eps)
+
+     # --- feed-forward ----------------------------------------------------
+     self.ffn_gate = nn.Linear(dim, hidden_dim, bias=False)
+     self.ffn_up = nn.Linear(dim, hidden_dim, bias=False)
+     self.ffn_down = nn.Linear(hidden_dim, dim, bias=False)
+
+   def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
+     x_norm = self.attn_norm(x) # (B,T,D)
+     q, k, v = self.attn_q(x_norm), self.attn_k(x_norm), self.attn_v(x_norm)
+
+     B, T, _ = x.shape
+     q = q.reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2) # (B,H,T,Hd)
+     k = k.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
+     v = v.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
+
+     q = apply_rope(q, start_pos)
+     k = apply_rope(k, start_pos)
+
+     # TODO: remove these kv cache realizes
+     if not hasattr(self, "cache_kv"):
+       self.cache_kv = Tensor.zeros(2, B, self.n_kv_heads, self.max_context, self.head_dim, dtype=k.dtype, device=k.device).contiguous().realize()
+     self.cache_kv[:, :, :, start_pos:start_pos+T, :].assign(Tensor.stack(k, v)).realize() # type: ignore
+     k = self.cache_kv[0, :, :, 0:start_pos+T, :]
+     v = self.cache_kv[1, :, :, 0:start_pos+T, :]
+
+     # NOTE: this mask is causal_lower_right, not the causal_upper_left generated by is_casual = True
+     mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1) if T > 1 else None
+     attn = q.scaled_dot_product_attention(k, v, attn_mask=mask, enable_gqa=True) # (B,H,T,Hd)
+     attn = attn.transpose(1, 2).reshape(B, T, -1) # back to (B,T,D)
+     attn = self.attn_output(attn)
+     return x + attn
+
+   def _feed_forward(self, h: Tensor) -> Tensor:
+     h_norm = self.ffn_norm(h)
+     gated = self.ffn_gate(h_norm).silu() * self.ffn_up(h_norm)
+     return h + self.ffn_down(gated)
+
+   def __call__(self, x: Tensor, start_pos: int|UOp):
+     return self._feed_forward(self._attention(x, start_pos))
+
+ class Transformer:
+   def __init__(self, *, num_blocks, dim, hidden_dim, n_heads, n_kv_heads, norm_eps, vocab_size, max_context):
+     self.blk = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context) for _ in range(num_blocks)]
+     self.token_embd = nn.Embedding(vocab_size, dim)
+     self.output_norm = nn.RMSNorm(dim, norm_eps)
+     self.output = nn.Linear(dim, vocab_size, bias=False)
+     self.max_context = max_context
+     # JIT is used if T=1 and start_pos is a UOp. TODO: make this not needed by including T in the JIT and making start_pos always a UOp
+     self.forward_jit = TinyJit(self.forward)
+
+   def forward(self, tokens:Tensor, start_pos:int|UOp) -> Tensor:
+     x = self.token_embd(tokens) # (B, T, D)
+     for block in self.blk: x = block(x, start_pos)
+     # TODO: add temperature
+     return self.output(self.output_norm(x))[:, -1, :].softmax(-1).argmax(-1, keepdim=True)
+
+   def __call__(self, tokens:Tensor, start_pos:int|UOp=0) -> Tensor:
+     return (self.forward_jit if getenv("JIT", 1) and tokens.shape[1] == 1 and isinstance(start_pos, UOp) else self.forward)(tokens, start_pos)
+
+   @staticmethod
+   def from_gguf(gguf:Tensor, max_context:int|None=None) -> tuple[Transformer, dict]:
+     # TODO: remove the need for copy to default device
+     kv, state_dict = nn.state.gguf_load(gguf.to(None))
+
+     # all state items should be float16, not float32
+     state_dict = {k:v.cast('float16') for k,v in state_dict.items()}
+
+     # some models like Llama 3.2 don't have an output.weight, they just tie to the token_embd.weight
+     if 'output.weight' not in state_dict: state_dict['output.weight'] = state_dict['token_embd.weight']
+
+     arch = kv['general.architecture']
+     max_context = min(max_context, kv[f'{arch}.context_length']) if max_context is not None else kv[f'{arch}.context_length']
+     model = Transformer(num_blocks=kv[f'{arch}.block_count'], dim=kv[f'{arch}.embedding_length'], hidden_dim=kv[f'{arch}.feed_forward_length'],
+                         n_heads=kv[f'{arch}.attention.head_count'], n_kv_heads=kv[f'{arch}.attention.head_count_kv'],
+                         norm_eps=kv[f'{arch}.attention.layer_norm_rms_epsilon'], vocab_size=len(kv['tokenizer.ggml.tokens']), max_context=max_context)
+     nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False) # NOTE: rope_freqs.weight (32,) is unused
+     return model, kv
+
+   def generate(self, tokens:list[int], start_pos=0):
+     v_start_pos = UOp.variable("start_pos", 1, self.max_context-1)
+     start_pos = 0
+     t = Tensor([tokens[start_pos:]], dtype="int32")
+     self.forward_jit.reset() # TODO: why is this required? root cause the issue and make it not be needed
+     while len(tokens) < self.max_context:
+       t = self(t, v_start_pos.bind(start_pos) if getenv("SYM", 1) and start_pos != 0 and t.shape[-1] == 1 else start_pos)
+       next_id = int(t.item())
+       tokens.append(next_id)
+       start_pos = len(tokens) - 1
+       yield next_id
+
+ models = {
+   "1B": "https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q6_K.gguf",
+   "3B": "https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/Llama-3.2-3B-Instruct-Q6_K.gguf",
+   "3B_f16": "https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/Llama-3.2-3B-Instruct-f16.gguf",
+   "8B": "https://huggingface.co/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/resolve/main/Meta-Llama-3.1-8B-Instruct-Q8_0.gguf",
+ }
+
+ if __name__ == "__main__":
+   parser = argparse.ArgumentParser()
+   parser.add_argument("--size", choices=list(models.keys()), default=list(models.keys())[0], help="Model size")
+   parser.add_argument("--max_context", type=int, default=4096, help="Max Context Length")
+   args = parser.parse_args()
+
+   # load the model
+   model, kv = Transformer.from_gguf(Tensor.from_url(models[args.size]), args.max_context)
+
+   # extract some metadata
+   tok = SimpleTokenizer.from_gguf_kv(kv)
+   bos_id: int = kv['tokenizer.ggml.bos_token_id']
+   eos_id: int = kv['tokenizer.ggml.eos_token_id']
+
+   ids: list[int] = [bos_id]
+   while 1:
+     start_pos = len(ids) - 1
+     try:
+       ids += tok.role("user") + tok.encode(input('>>> ')) + [eos_id] + tok.role("assistant")
+     except EOFError:
+       break
+     for next_id in model.generate(ids, start_pos):
+       sys.stdout.write(tok.decode([next_id]) if next_id != eos_id else "\n\n")
+       sys.stdout.flush()
+       if next_id == eos_id: break
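
tinygrad/apps/llm.py is a new, self-contained Llama-3-style chat app: a llama.cpp-compatible BPE tokenizer (SimpleTokenizer), RoPE, grouped-query attention with a KV cache, and a Transformer that loads weights directly from a GGUF file. The __main__ block above runs it interactively; below is a minimal non-interactive sketch of the same API, assuming the wheel exposes the module as tinygrad.apps.llm (the prompt text and the "1B" model choice are illustrative only):

import sys
from tinygrad import Tensor
from tinygrad.apps.llm import Transformer, SimpleTokenizer, models  # assumed import path for the new module

# download/load a GGUF checkpoint and build the tokenizer from its metadata
model, kv = Transformer.from_gguf(Tensor.from_url(models["1B"]), max_context=4096)
tok = SimpleTokenizer.from_gguf_kv(kv)
bos_id, eos_id = kv['tokenizer.ggml.bos_token_id'], kv['tokenizer.ggml.eos_token_id']

# build a chat prompt the same way the __main__ loop does: bos, user header, message, eos, assistant header
ids = [bos_id] + tok.role("user") + tok.encode("What is tinygrad?") + [eos_id] + tok.role("assistant")
for next_id in model.generate(ids, start_pos=len(ids) - 1):
  if next_id == eos_id: break
  sys.stdout.write(tok.decode([next_id]))
  sys.stdout.flush()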
tinygrad/codegen/__init__.py ADDED
@@ -0,0 +1,116 @@
+ from typing import Any, Callable
+ import functools
+ from dataclasses import dataclass
+ from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL
+ from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp
+ from tinygrad.uop.spec import type_verify
+ from tinygrad.renderer import Renderer
+
+ # import all pattern matchers here
+ from tinygrad.codegen.lowerer import pm_lowerer, get_index
+ from tinygrad.codegen.quantize import pm_quant
+ from tinygrad.codegen.gpudims import pm_add_gpudims
+ from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing
+ from tinygrad.uop.decompositions import get_late_rewrite_patterns
+ from tinygrad.codegen.expander import migrate_indexing, expander
+ from tinygrad.codegen.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
+   ReduceContext, correct_load_store, pm_render
+ from tinygrad.codegen.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
+ from tinygrad.codegen.opt import pm_optimize
+ from tinygrad.codegen.opt.swizzler import view_left, view_right, fix_kernel_ops
+
+ @dataclass
+ class RewriteStep:
+   pm: PatternMatcher
+   ctx: Callable[[UOp], Any]|None = None
+   name: str|None = None
+   bottom_up: bool = False
+   def __call__(self, sink:UOp):
+     return graph_rewrite(sink, self.pm, ctx=self.ctx(sink) if self.ctx is not None else None, name=self.name, bottom_up=self.bottom_up)
+
+ def apply_rewrites(sink:UOp, rewrites:list[RewriteStep]): return functools.reduce(lambda x,f: f(x), rewrites, sink)
+
+ rewrites_for_views = [
+   RewriteStep(view_left, name="Main View Left"),
+   RewriteStep(view_right, name="Main View Right"),
+   RewriteStep(view_left+fix_kernel_ops, bottom_up=True, name="Finalize Kernel"),
+ ]
+
+ rewrites_for_linearizer = [
+   RewriteStep(block_create, ctx=BlockContext.from_sink, name="Linearizer: Create Blocks", bottom_up=True),
+   RewriteStep(pm_blockend_merge, name="Linearizer: Merge Blockends"),
+   RewriteStep(block_merge, name="Linearizer: Merge Blocks"),
+   RewriteStep(pm_finalize, name="Linearizer: Finalize")]
+
+ def get_rewrites_for_renderer(opts:Renderer, linearizer:bool=True) -> list[RewriteStep]:
+   # cache with the values of the context vars
+   return _get_rewrites_for_renderer(opts, linearizer, QUANTIZE.value, DEVECTORIZE.value, TRANSCENDENTAL.value)
+
+ @functools.cache
+ def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL) -> list[RewriteStep]:
+   # ** lowerer (rewrite_shapetracker_with_index) **
+   ret: list[RewriteStep] = []
+
+   # view pushing
+   ret.extend(rewrites_for_views)
+
+   # this is kernel.py
+   ret.append(RewriteStep(pm_optimize, ctx=lambda _: opts, name="optimize ast"))
+
+   if _QUANTIZE and opts.device in {"CPU", "DSP"}: ret.append(RewriteStep(pm_quant, name="quantize"))
+   ret.append(RewriteStep(pm_lowerer, get_index, name="lowerer", bottom_up=True))
+
+   # ** expander (expand_rewrite) **
+   ret.append(RewriteStep(sym+migrate_indexing, name="initial symbolic"))
+
+   # expand
+   ret.append(RewriteStep(sym+expander, name="expander"))
+
+   # ** devectorizer (full_graph_rewrite) **
+   # remove reduce
+   ret.append(RewriteStep(pm_reduce+gep_pushing, lambda _: ReduceContext(), name="remove_reduce"))
+
+   # add gpu dims (late)
+   ret.append(RewriteStep(pm_add_gpudims, lambda _: opts, name="add gpudims"))
+
+   # devectorize (TODO: does this need opts?)
+   if _DEVECTORIZE >= 2: pm_devectorize = sym+load_store_folding+load_store_indexing
+   elif _DEVECTORIZE: pm_devectorize = sym+devectorize+load_store_folding+correct_load_store+load_store_indexing
+   else: pm_devectorize = sym+load_store_folding+correct_load_store+load_store_indexing
+   ret.append(RewriteStep(pm_devectorize, lambda _: opts, name="devectorize"))
+
+   supported_ops = tuple(opts.code_for_op.keys())
+   extra_matcher = opts.extra_matcher if opts.extra_matcher is not None else PatternMatcher([])
+
+   # optional pre matcher
+   if opts.pre_matcher is not None: ret.append(RewriteStep(opts.pre_matcher, name="pre_matcher"))
+
+   # decompositions
+   pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, _TRANSCENDENTAL>=2)
+   ret.append(RewriteStep(pm_decomp, lambda _: opts.device, name="decompositions"))
+
+   # final rules for the renderer (without sym)
+   pm_final_rewrite = pm_decomp+pm_render+extra_matcher
+   ret.append(RewriteStep(pm_final_rewrite, lambda _: opts.device, name="final rewrite"))
+
+   # return the list (with optional linearizer)
+   return ret + (rewrites_for_linearizer if linearizer else [])
+
+ def full_rewrite_to_sink(sink:UOp, opts:Renderer|None=None, linearizer:bool=False) -> UOp:
+   return apply_rewrites(sink, get_rewrites_for_renderer(opts if opts is not None else Renderer(), linearizer))
+
+ def full_rewrite(sink:UOp, opts:Renderer|None=None) -> list[UOp]:
+   """
+   Function to transform the Kernel UOp graph into a linearized program.
+
+   Args:
+     sink: The Ops.SINK rooting the Kernel graph.
+     opts: The Renderer (can change how things are processed, fix this).
+
+   Returns:
+     Linear program in UOps.
+   """
+
+   lst = list(full_rewrite_to_sink(sink, opts, linearizer=True).arg.lst)
+   if __debug__: type_verify(lst)
+   return lst
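
This new tinygrad/codegen/__init__.py replaces the old monolithic kernel.py lowering flow with an explicit, named list of RewriteStep passes; full_rewrite is just a fold of those passes over the kernel's SINK (apply_rewrites is a functools.reduce), followed by type_verify. A minimal sketch of that equivalence, assuming `sink` is a kernel Ops.SINK UOp produced by the scheduler and `opts` is the Renderer for the target device (neither is constructed here):

# a minimal sketch, assuming `sink` is a kernel Ops.SINK UOp and `opts` is the target device's Renderer
from tinygrad.codegen import full_rewrite, get_rewrites_for_renderer, apply_rewrites

steps = get_rewrites_for_renderer(opts, linearizer=True)  # list[RewriteStep], cached per (opts, linearizer, context vars)
linearized = apply_rewrites(sink, steps)                  # folds graph_rewrite over the sink, one call per step
uops = list(linearized.arg.lst)                           # the linear program; full_rewrite(sink, opts) returns this
                                                          # after running type_verify on it under __debug__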