tinygrad 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff shows the changes between two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
- tinygrad/__init__.py +6 -6
- tinygrad/codegen/kernel.py +230 -190
- tinygrad/codegen/linearizer.py +278 -384
- tinygrad/codegen/uops.py +415 -0
- tinygrad/device.py +132 -275
- tinygrad/dtype.py +53 -37
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +195 -0
- tinygrad/engine/realize.py +191 -0
- tinygrad/engine/schedule.py +362 -0
- tinygrad/engine/search.py +196 -0
- tinygrad/{mlops.py → function.py} +28 -14
- tinygrad/helpers.py +72 -43
- tinygrad/lazy.py +141 -240
- tinygrad/multi.py +169 -0
- tinygrad/nn/__init__.py +179 -8
- tinygrad/nn/datasets.py +7 -0
- tinygrad/nn/optim.py +106 -28
- tinygrad/nn/state.py +86 -17
- tinygrad/ops.py +70 -44
- tinygrad/renderer/__init__.py +61 -0
- tinygrad/renderer/assembly.py +276 -0
- tinygrad/renderer/cstyle.py +299 -206
- tinygrad/renderer/llvmir.py +118 -123
- tinygrad/runtime/autogen/amd_gpu.py +1900 -0
- tinygrad/runtime/autogen/comgr.py +865 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5761 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33328 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/hip_comgr.py +47 -0
- tinygrad/runtime/driver/hsa.py +143 -0
- tinygrad/runtime/graph/clang.py +38 -0
- tinygrad/runtime/graph/cuda.py +59 -54
- tinygrad/runtime/graph/hcq.py +143 -0
- tinygrad/runtime/graph/hsa.py +171 -0
- tinygrad/runtime/graph/metal.py +37 -41
- tinygrad/runtime/ops_amd.py +564 -0
- tinygrad/runtime/ops_clang.py +16 -14
- tinygrad/runtime/ops_cuda.py +130 -38
- tinygrad/runtime/ops_disk.py +45 -42
- tinygrad/runtime/ops_gpu.py +52 -50
- tinygrad/runtime/ops_hsa.py +278 -0
- tinygrad/runtime/ops_llvm.py +36 -56
- tinygrad/runtime/ops_metal.py +42 -24
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +630 -0
- tinygrad/runtime/ops_python.py +204 -0
- tinygrad/shape/shapetracker.py +41 -105
- tinygrad/shape/symbolic.py +98 -95
- tinygrad/shape/view.py +137 -35
- tinygrad/tensor.py +2367 -442
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/METADATA +19 -9
- tinygrad-0.9.0.dist-info/RECORD +60 -0
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
- tinygrad/features/image.py +0 -93
- tinygrad/features/multi.py +0 -103
- tinygrad/features/search.py +0 -160
- tinygrad/graph.py +0 -106
- tinygrad/jit.py +0 -152
- tinygrad/realize.py +0 -50
- tinygrad/runtime/graph/hip.py +0 -24
- tinygrad/runtime/ops_cpu.py +0 -45
- tinygrad/runtime/ops_hip.py +0 -97
- tinygrad/runtime/ops_torch.py +0 -49
- tinygrad-0.8.0.dist-info/RECORD +0 -41
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
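The headline change in this release is the repackaging of the library: `tinygrad/jit.py`, `tinygrad/realize.py`, `tinygrad/graph.py`, and the `tinygrad/features/` modules move under `tinygrad/engine/` (or to the top level, in the case of `multi.py`), `mlops.py` is renamed to `function.py`, and the CPU, torch, and HIP runtimes give way to HSA, AMD, NV, and Python backends. For code written against 0.8.0, the sketch below illustrates one way to bridge the two layouts; the module paths come from the file listing above, but the assumption that each symbol kept its exact name in 0.9.0 is the editor's, not verified against the new source.

```python
# Hedged sketch: import shims for downstream code written against tinygrad 0.8.0.
# New module paths are taken from the file listing in this diff; whether each
# symbol kept its exact name in 0.9.0 is an assumption.
try:
    # 0.9.0 layout
    from tinygrad.engine.jit import TinyJit            # was tinygrad/jit.py
    from tinygrad.engine.search import beam_search     # was tinygrad/features/search.py
    from tinygrad.multi import MultiLazyBuffer         # was tinygrad/features/multi.py
except ImportError:
    # 0.8.0 layout
    from tinygrad.jit import TinyJit
    from tinygrad.features.search import beam_search
    from tinygrad.features.multi import MultiLazyBuffer
```

Where the top-level re-exports in `tinygrad/__init__.py` cover the symbol you need (e.g. `from tinygrad import TinyJit`), importing from `tinygrad` directly sidesteps the internal layout entirely.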
{tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE

@@ -1,4 +1,4 @@
-Copyright (c)
+Copyright (c) 2024, the tiny corp

 Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

{tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: tinygrad
-Version: 0.8.0
+Version: 0.9.0
 Summary: You like pytorch? You like micrograd? You love tinygrad! <3
 Author: George Hotz
 License: MIT

@@ -11,11 +11,16 @@ Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: numpy
 Requires-Dist: tqdm
-Requires-Dist: gpuctypes
 Requires-Dist: pyobjc-framework-Metal ; platform_system == "Darwin"
 Requires-Dist: pyobjc-framework-libdispatch ; platform_system == "Darwin"
 Provides-Extra: arm
 Requires-Dist: unicorn ; extra == 'arm'
+Provides-Extra: docs
+Requires-Dist: mkdocs-material ; extra == 'docs'
+Requires-Dist: mkdocstrings[python] ; extra == 'docs'
+Requires-Dist: markdown-callouts ; extra == 'docs'
+Requires-Dist: markdown-exec[ansi] ; extra == 'docs'
+Requires-Dist: black ; extra == 'docs'
 Provides-Extra: linting
 Requires-Dist: pylint ; extra == 'linting'
 Requires-Dist: mypy ; extra == 'linting'

@@ -30,7 +35,7 @@ Requires-Dist: torch ; extra == 'testing'
 Requires-Dist: pillow ; extra == 'testing'
 Requires-Dist: pytest ; extra == 'testing'
 Requires-Dist: pytest-xdist ; extra == 'testing'
-Requires-Dist: onnx ==1.
+Requires-Dist: onnx ==1.16.0 ; extra == 'testing'
 Requires-Dist: onnx2torch ; extra == 'testing'
 Requires-Dist: opencv-python ; extra == 'testing'
 Requires-Dist: tabulate ; extra == 'testing'

@@ -41,12 +46,19 @@ Requires-Dist: tiktoken ; extra == 'testing'
 Requires-Dist: librosa ; extra == 'testing'
 Requires-Dist: networkx ; extra == 'testing'
 Requires-Dist: hypothesis ; extra == 'testing'
+Requires-Dist: nibabel ; extra == 'testing'
+Provides-Extra: testing_tf
+Requires-Dist: tensorflow ==2.15.1 ; extra == 'testing_tf'
+Requires-Dist: tensorflow-addons ; extra == 'testing_tf'
 Provides-Extra: triton
 Requires-Dist: triton-nightly >=2.1.0.dev20231014192330 ; extra == 'triton'

 <div align="center">

-
+<picture>
+  <source media="(prefers-color-scheme: light)" srcset="/docs/logo_tiny_light.svg">
+  <img alt="tiny corp logo" src="/docs/logo_tiny_dark.svg" width="50%" height="50%">
+</picture>

 tinygrad: For something between [PyTorch](https://github.com/pytorch/pytorch) and [karpathy/micrograd](https://github.com/karpathy/micrograd). Maintained by [tiny corp](https://tinygrad.org).

@@ -122,17 +134,14 @@ See [examples/beautiful_mnist.py](examples/beautiful_mnist.py) for the full vers

 tinygrad already supports numerous accelerators, including:

-- [x] [CPU](tinygrad/runtime/ops_cpu.py)
 - [x] [GPU (OpenCL)](tinygrad/runtime/ops_gpu.py)
-- [x] [C Code
+- [x] [CLANG (C Code)](tinygrad/runtime/ops_clang.py)
 - [x] [LLVM](tinygrad/runtime/ops_llvm.py)
 - [x] [METAL](tinygrad/runtime/ops_metal.py)
 - [x] [CUDA](tinygrad/runtime/ops_cuda.py)
-- [x] [
-- [x] [HIP](tinygrad/runtime/ops_hip.py)
+- [x] [HSA](tinygrad/runtime/ops_hsa.py)

 And it is easy to add more! Your accelerator of choice only needs to support a total of ~25 low level ops.
-More information can be found in the [documentation for adding new accelerators](/docs/adding_new_accelerators.md).

 ## Installation

@@ -193,6 +202,7 @@ We'll start with what will get your PR closed with a pointer to this section:
 - All docs and whitespace changes will be closed unless you are a well-known contributor. The people writing the docs should be those who know the codebase the absolute best. People who have not demonstrated that shouldn't be messing with docs. Whitespace changes are both useless *and* carry a risk of introducing bugs.
 - Anything you claim is a "speedup" must be benchmarked. In general, the goal is simplicity, so even if your PR makes things marginally faster, you have to consider the tradeoff with maintainablity and readablity.
 - In general, the code outside the core `tinygrad/` folder is not well tested, so unless the current code there is broken, you shouldn't be changing it.
+- If your PR looks "complex", is a big diff, or adds lots of lines, it won't be reviewed or merged. Consider breaking it up into smaller PRs that are individually clear wins. A common pattern I see is prerequisite refactors before adding new functionality. If you can (cleanly) refactor to the point that the feature is a 3 line change, this is great, and something easy for us to review.

 Now, what we want:

tinygrad-0.9.0.dist-info/RECORD

@@ -0,0 +1,60 @@
+tinygrad/__init__.py,sha256=jC-35zswLSXLuRRThG_o6yar6qQjLCqmeaFCj_XKN08,449
+tinygrad/device.py,sha256=zXcrFjBsiV1rW0aXupszDjD98TWLHin7u8pBd5fdJqo,10446
+tinygrad/dtype.py,sha256=xg2BlFIPcQw0onHW_0ktGXjved9SXgQcLNrqe6gCXto,6221
+tinygrad/function.py,sha256=0xkWst2tRsOeN6YcQS65MfVfWwKQYFkAacgkTys0VdQ,9616
+tinygrad/helpers.py,sha256=XI8MIeBE35wQ4q0NEsUCvkj3QdY0adI80SfCbOySOVI,12773
+tinygrad/lazy.py,sha256=xqaEqXaIpt_77SP_2U6Pyfw8YeGd_0PzNDXJOOnRJ24,13379
+tinygrad/multi.py,sha256=gyGXYVviaPfzAkoJjLFUiusVd3no6HRJunOOxD0DaaY,11362
+tinygrad/ops.py,sha256=aNk1jLuJl--Z_u8DE8Du1iN1AFhMH-6Le5mnvRHLDvI,7124
+tinygrad/tensor.py,sha256=nznRGHH7-64bMpFeu8gvdSDfJ5CEEoezIcHSxPgeJ7k,129223
+tinygrad/codegen/kernel.py,sha256=RRRmOX3iOOgu5ISABW_UTVh5vfGdLFfK1UOtKtpghuY,38169
+tinygrad/codegen/linearizer.py,sha256=jxwEcxxcpWOvYlIgXmlGkgNYZ4sDykiY_5Ve3a8tpYg,27622
+tinygrad/codegen/uops.py,sha256=yKS-3w9teuS_3BLnHAN4vWtSRvZHmsx194YRBXMOhFI,21872
+tinygrad/engine/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tinygrad/engine/graph.py,sha256=eEbb17qbJ0A-2VjN4l7SCbA4yI7jh6YN6EybPShmcbg,5221
+tinygrad/engine/jit.py,sha256=TrdZQEXnF-SowCAyy69iEp85NPFFtOVaIz_2i1cwCvQ,11049
+tinygrad/engine/realize.py,sha256=H2CgiLWRqTQSr9udJb1iLX7DFdLcPLwxSJZ2bUfHXDs,11077
+tinygrad/engine/schedule.py,sha256=I3OxiNwveWbVuhhFWdR1B5G8goEyqXOHWe7UQQ3Ogz8,18487
+tinygrad/engine/search.py,sha256=M11qHlufIffBS9b7mjk8gniyoGRaengwXnMfUQTHmyw,11289
+tinygrad/nn/__init__.py,sha256=DoHrq9pUFs1vm9FUA5eet5_tvhlZJlC_uC4kBqo98kI,12932
+tinygrad/nn/datasets.py,sha256=Mvf_0eCEEqtB9-8iiyDFhm-B7rTyWXwmSY_3xjeHleo,458
+tinygrad/nn/optim.py,sha256=zf85kwumpS17fk1NBjUfe7tOUE7-XH37SL-LjgAfva8,6799
+tinygrad/nn/state.py,sha256=nGR05s3kuDNp9lliCIr4-6Ek7Korha7jCAWie5S2rB4,10138
+tinygrad/renderer/__init__.py,sha256=-LjQ9tC2rI8fveaS_xn24X_knXKILFj-iZFcRTk8fNM,2672
+tinygrad/renderer/assembly.py,sha256=MD-SSC7-Nqwt3zrwe0aDXVX08W9Ox6Vj_byPS1k1bAQ,17923
+tinygrad/renderer/cstyle.py,sha256=tFWWW-egorLFEDwX6fA9-rYxvNLc67LjxlZ6JzrWCF0,24043
+tinygrad/renderer/llvmir.py,sha256=BZViWXj2G6JtEcOgc-CtnIj-d9xP0ZjgNXdTKQT_PJ8,10315
+tinygrad/runtime/ops_amd.py,sha256=3jOrFqxk8JkPX043tEUFLvyKSgX7Fls785g_gOkdzVM,31811
+tinygrad/runtime/ops_clang.py,sha256=XWqwobReRdu-Tj-chbWEJFMx6AQfgdGCcpdWcLWUTOQ,1468
+tinygrad/runtime/ops_cuda.py,sha256=cgeoVpY9bOGU22Eh78XR5YOYY2cgsJt4Vnxl6u8N6co,10840
+tinygrad/runtime/ops_disk.py,sha256=75-iihZxkhNvA5O3VaW61LOXwmlSX4XwegpnV1C4D5A,2738
+tinygrad/runtime/ops_gpu.py,sha256=FB3Fp-VVEDGEt_6CfJxsM_TWzhp5giXCP1TSSRMXE80,7532
+tinygrad/runtime/ops_hsa.py,sha256=YNQLqZjJ9twTJRKS41l2oIrncOAu3wdOdBegs9zYlgo,16188
+tinygrad/runtime/ops_llvm.py,sha256=dODiyVSlPofHyDIZrD-V74h8W1d94VPnr_-A4gNbSO4,2229
+tinygrad/runtime/ops_metal.py,sha256=fGSNpwmYIHaif9a5SiwyMX2bub-r5hTNpnrqlaPMeUc,5815
+tinygrad/runtime/ops_npy.py,sha256=qaAi0AEo6nt7iZ-eWqM8z2aQfNJgZUpmBCEDmrIzWL0,369
+tinygrad/runtime/ops_nv.py,sha256=PCMAHMrW4J7esgnkpwq3bB91Q3h6hBATr8JuykR9vcA,37633
+tinygrad/runtime/ops_python.py,sha256=mmsDj1hJ3BtreAq5dfCuuUGbgoIhCKlVwNqMDmXBISs,10871
+tinygrad/runtime/autogen/amd_gpu.py,sha256=1NDH0ualiZ8OtgTjaYcQ1HjKs_SQ7eUHuJvdrDodvCk,65022
+tinygrad/runtime/autogen/comgr.py,sha256=Z99Y6K8D_nuMpOs0qDWiA0MV-RxueV65o2OyPFdcsHE,38563
+tinygrad/runtime/autogen/cuda.py,sha256=GgRl4AfU54JG0G1XJj2dq2FbrUZ8XG_AnFrPAZJpSSg,250380
+tinygrad/runtime/autogen/hip.py,sha256=1yUHDCwL3KkD15if2Q1Ud3GbJiR7DxsNorKZTCINw54,245532
+tinygrad/runtime/autogen/hsa.py,sha256=tGpnXUhhQkAIEr0yyCxRImzajmt-nN0KzJn4KnT_bH8,270073
+tinygrad/runtime/autogen/kfd.py,sha256=dDmLFL6HL_QXRW2rZOCArY55PRuXuLN9563XCipV2jM,29935
+tinygrad/runtime/autogen/nv_gpu.py,sha256=K9WwwdIitHrY2AXpYy8bbdD9aEwdbz9vL7748pz6Re0,1672024
+tinygrad/runtime/autogen/opencl.py,sha256=aW-luGFF5PXFEZ6SgrGANhA9qpkI-fZyEsmDfpez2Ss,82654
+tinygrad/runtime/driver/hip_comgr.py,sha256=rFQRsOYo4XquwcHFTe2mGzMfozdL9hitO3DRYBDFSuM,3376
+tinygrad/runtime/driver/hsa.py,sha256=PoNy8gHBPoRUhUZligFp0z_Le9fyEXbJrnlgwInt_R0,7218
+tinygrad/runtime/graph/clang.py,sha256=10Bs64J0r12g6upqCHVoK3LoTrdbBBHQD43efMhlBjo,1957
+tinygrad/runtime/graph/cuda.py,sha256=LNx6RQLcQSKMlHfVK5r_efujN0lRPhKqi8yp249OAIs,5265
+tinygrad/runtime/graph/hcq.py,sha256=mspwzroBTwNNHDob7oK-JCt48mhuIhX_G0qNYvFVuVM,8089
+tinygrad/runtime/graph/hsa.py,sha256=UJgSg2irrKT87LBZ3DfaGmoK7rJk8OZhIHEHhtF8rUE,10035
+tinygrad/runtime/graph/metal.py,sha256=bwB6uAsqjEbwv5ML5ziWduBtmTpseJboo6J9ssVa4v4,4579
+tinygrad/shape/shapetracker.py,sha256=hWqh2uWsbBp3lKlRpY8Fj1oTWvEx1YwVsKl0QiA-QnU,6334
+tinygrad/shape/symbolic.py,sha256=hn2khLoHAJSwyZ91i679oJZCLTaz0Sf2dUG-HRJMtVw,16688
+tinygrad/shape/view.py,sha256=KMf_KzNwXmcX1NbFPq862-Jv_E6TgeO27lcPjrAweF4,17092
+tinygrad-0.9.0.dist-info/LICENSE,sha256=ABRhUPEILzINYIukgazD-_rPipkUNUwslrb0RxnV6Xc,1058
+tinygrad-0.9.0.dist-info/METADATA,sha256=oyGO3WSmMQ7NTAK3RGk0ZXCkr-L3XKltKYhYrKEuifk,10227
+tinygrad-0.9.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+tinygrad-0.9.0.dist-info/top_level.txt,sha256=vDABMCWBFQnx2kn9Azueu88FP-1klQdePoHikQhHymc,9
+tinygrad-0.9.0.dist-info/RECORD,,

tinygrad/features/image.py DELETED

@@ -1,93 +0,0 @@
-from typing import Tuple
-from tinygrad.helpers import prod, IMAGE, getenv, DEBUG
-from tinygrad.dtype import dtypes
-
-# *** image Tensor function replacements ***
-
-def image_dot(self, w):
-  # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
-  n1, n2 = len(self.shape), len(w.shape)
-  assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
-  assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})" # noqa: E501
-  bs, groups, cin, cout = prod(self.shape[0:-2]), prod(w.shape[0:-2]), w.shape[-2], w.shape[-1]
-  out_shape_t = self.shape[0:-2] + (cout,-1) if len(self.shape) > 1 else (cout, )
-
-  # NOTE: with NHWC we can remove the transposes
-  # bs x groups*cin x H x W
-  cx = self.transpose(self.ndim-1, self.ndim-2).reshape((bs//groups, groups*cin, -1, 1))
-  # groups*cout x cin x H, W
-  cw = w.transpose(w.ndim-1, w.ndim-2).reshape((groups*cout, cin, 1, 1))
-  return image_conv2d(cx, cw, groups=groups).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)
-
-def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, padding=0):
-  base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef
-
-  (bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
-  x, w = self, weight.reshape(groups, (rcout := cout//groups), cin, H, W)
-
-  # hack for non multiples of 4 on cin
-  if cin % 4 != 0 and not (cin == 1 and groups%4 == 0):
-    x = x.reshape(bs, groups, cin, iy, ix) # do this always?
-    added_input_channels = 4 - (cin % 4)
-    w = w.pad(tuple((0, added_input_channels) if i == 2 else None for i in range(w.ndim)))
-    x = x.pad(tuple((0, added_input_channels) if i == 2 else None for i in range(x.ndim)))
-    cin = cin + added_input_channels
-    x = x.reshape(bs, groups*cin, iy, ix)
-
-  # hack for non multiples of 4 on rcout
-  added_output_channels = 0
-  if rcout % 4 != 0 and not (rcout == 1 and groups%4 == 0):
-    added_output_channels = 4 - (rcout % 4)
-    rcout += added_output_channels
-    cout = groups * rcout
-    w = w.pad(tuple((0, added_output_channels) if i == 1 else None for i in range(w.ndim)))
-
-  # packed (note: flipping bs and iy would make the auto-padding work)
-  x = x.permute(0,2,3,1)
-  cin_last = iy == 1 and ix == 1
-  if cin == 1: w = w.reshape(cout//4,4,H,W).permute(0,2,3,1)
-  elif cin_last: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,1,3)
-  else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,3,1)
-
-  # contiguous creates the image, and early realize static weights (TODO: test for the static weight)
-  if IMAGE >= 2: x,w = x.cast(base_image_type((bs*iy, ix*groups*cin//4, 4))), w.cast(base_image_type((cout//4, H*W*cin, 4)))
-  x, w = x.contiguous(), w.contiguous()
-
-  # expand out
-  rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1
-  cout_expand = [groups//4 if cin == 1 else groups, 4 if cin == 1 else 1, rcout//4 if rcout >= 4 else 1, 4 if rcout >= 4 else 1]
-  x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo)
-  if cin_last: w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo)
-  else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4)
-
-  # padding
-  padding_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]])
-  x = x.slice((None, (-padding_[2], x.shape[1]+padding_[3]), (-padding_[0], x.shape[2]+padding_[1]), None, None, None))
-
-  # prepare input
-  x = x.permute(0,3,4,5,1,2)._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
-  x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, (oy := x.shape[4]), (ox := x.shape[5]), *cout_expand[0:2], 1, 1, rcin_hi, rcin_lo, H, W)
-
-  # prepare weights
-  w = w.permute(0,4,2,5,1,3).reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W))
-
-  # the conv!
-  ret = (x*w).cast(base_image_type((bs*oy, ox*cout//4, 4)) if IMAGE >= 2 else dtypes.float32).sum((-4, -3, -2, -1))
-
-  # undo hack for non multiples of 4 on C.rcout
-  if added_output_channels != 0:
-    ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :, :, :-added_output_channels]
-    cout = groups * (rcout - added_output_channels)
-
-  # NCHW output
-  ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2)
-  return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
-
-# *** images have weird indexing requirements ***
-
-from tinygrad.shape.symbolic import Node
-def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node) -> Tuple[Tuple[Node, Node], Node]:
-  idx, idy = (idxy // 4) % base_shape[1], (idxy // (4 * base_shape[1]))
-  # TODO: bring back the valid removal logic (correct!)
-  if DEBUG>=5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy, valid)
-  return (idx, idy), valid

tinygrad/features/multi.py DELETED

@@ -1,103 +0,0 @@
-from __future__ import annotations
-from typing import Optional, Union, Any, Tuple, List
-import functools
-from tinygrad.helpers import all_same, dedup
-from tinygrad.dtype import DType
-from tinygrad.ops import BinaryOps, LoadOps, UnaryOps, TernaryOps, ReduceOps
-from tinygrad.lazy import LazyBuffer, create_schedule
-from tinygrad.shape.shapetracker import ShapeTracker, sint
-
-def all_reduce(lbs):
-  # TODO: replace this with ring reduce
-  return [functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
-
-def to_sharded(lbs:List[LazyBuffer], axis:int) -> List[LazyBuffer]:
-  assert lbs[0].shape[axis] % len(lbs) == 0, f"{lbs[0].shape=} {axis=} {len(lbs)=}"
-  sz = lbs[0].shape[axis] // len(lbs)
-  return [lb.shrink(tuple((0,s) if a != axis else (sz*i,sz*(i+1)) for a,s in enumerate(lb.shape))) for i,lb in enumerate(lbs)]
-
-class MultiLazyBuffer:
-  def __init__(self, lbs:List[LazyBuffer], axis:Optional[int]):
-    assert all(isinstance(x, LazyBuffer) for x in lbs) and len(lbs) >= 2, "all lbs must be LazyBuffers, and we need at least two of them"
-    assert all_same([(x.shape, x.dtype, x.st) for x in lbs]), "all multilazybuffer needs same shape, dtype, and st"
-    self.lbs, self.axis, self.dtype, self.device = lbs, axis, lbs[0].dtype, tuple(x.device for x in lbs)
-    self.shape = tuple(s*len(self.lbs) if a == self.axis else s for a,s in enumerate(lbs[0].shape))
-
-  def __repr__(self):
-    return f"<MLB{chr(10)}{chr(10).join([f'{x.device} {x.st}' for x in self.lbs])}>"
-
-  @staticmethod
-  def from_sharded(lb:LazyBuffer, devices:Tuple[str, ...], axis:Optional[int]=None):
-    lbs = [lb.contiguous() if lb.base != lb else lb] * len(devices)
-    return MultiLazyBuffer([lb.copy_to_device(d).contiguous() for lb,d in zip(to_sharded(lbs, axis) if axis is not None else lbs, devices)], axis)
-
-  def copy_to_device(self, device:str) -> LazyBuffer:
-    if self.axis is None: return self.lbs[0].copy_to_device(device)
-    sz = self.lbs[0].shape[self.axis]
-    llbs = []
-    for i,lb in enumerate([lb.copy_to_device(device) for lb in self.lbs]):
-      pad_arg = tuple((0,0) if a != self.axis else (sz*i,(s*len(self.lbs))-sz*(i+1)) for a,s in enumerate(lb.shape))
-      llbs.append(lb.pad(pad_arg))
-    return functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), llbs)
-
-  # TODO: fix this
-  def is_unrealized_contiguous_const(self): return False
-
-  # passthroughs
-  def schedule(self, seen=None): return create_schedule(self.lbs, seen)
-  def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis)
-  def const(self, val:Union[float, int]) -> MultiLazyBuffer: return MultiLazyBuffer([x.const(val) for x in self.lbs], self.axis)
-  def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis)
-
-  # elementwise is simple
-  def e(self, op:Union[LoadOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:MultiLazyBuffer, arg:Optional[Any]=None) -> MultiLazyBuffer:
-    msrcs = (self,)+in_srcs
-    assert all(isinstance(x, MultiLazyBuffer) for x in msrcs), f"all buffers must be MultiLazyBuffer {msrcs}"
-    assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"
-
-    # NOTE: they all have to share an axis, we always choose [-1]
-    axis = axes[-1] if len(axes := dedup([x.axis for x in msrcs if x.axis is not None])) else None
-    srcs = []
-    for mlb in msrcs:
-      if mlb.axis == axis: srcs.append(mlb.lbs)
-      elif mlb.axis is None and axis is not None: srcs.append(to_sharded(mlb.lbs, axis))
-      else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis))
-    return MultiLazyBuffer([lsrcs[0].e(op, *lsrcs[1:], arg=arg) for lsrcs in zip(*srcs)], axis)
-
-  def _shape_to_single_shard(self, shape): return tuple(s//len(self.lbs) if a == self.axis else s for a,s in enumerate(shape))
-
-  def r(self, op:ReduceOps, new_shape:Tuple[sint, ...]) -> MultiLazyBuffer:
-    if self.axis is not None and new_shape[self.axis] == 1:
-      # all-reduce on sharded axes
-      return MultiLazyBuffer(all_reduce([x.r(op, new_shape) for x in self.lbs]), None)
-    # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
-    return MultiLazyBuffer([x.r(op, self._shape_to_single_shard(new_shape)) for x in self.lbs], self.axis)
-
-  # *** movement ops ***
-
-  def reshape(self, arg:Tuple[sint, ...]):
-    if self.axis is None: return MultiLazyBuffer([x.reshape(arg) for x in self.lbs], None)
-    # TODO: this can be wrong
-    st = ShapeTracker.from_shape(self.shape)
-    rs = st.real_strides()[self.axis]
-    new_axis = st.reshape(arg).real_strides().index(rs)
-    narg = tuple(s//len(self.lbs) if a == new_axis else s for a,s in enumerate(arg))
-    return MultiLazyBuffer([x.reshape(narg) for x in self.lbs], new_axis)
-
-  def pad(self, arg:Tuple[Tuple[sint, sint], ...]):
-    assert self.axis is None or arg[self.axis] == (0,0), "padding not supported on sharded axis"
-    return MultiLazyBuffer([x.pad(arg) for x in self.lbs], self.axis)
-  def expand(self, arg:Tuple[sint, ...]):
-    # NOTE: this assert isn't needed, sharded axis can have dim 1
-    assert self.axis is None or arg[self.axis] == self.lbs[0].shape[self.axis] * len(self.lbs), "expand not supported on sharded axis"
-    return MultiLazyBuffer([x.expand(self._shape_to_single_shard(arg)) for x in self.lbs], self.axis)
-  def permute(self, arg:Tuple[int, ...]):
-    # all permutes supported!
-    return MultiLazyBuffer([x.permute(arg) for x in self.lbs], arg.index(self.axis) if self.axis is not None else None)
-  def shrink(self, arg:Tuple[Tuple[sint, sint], ...]):
-    assert self.axis is None or arg[self.axis] == (0, self.lbs[0].shape[self.axis] * len(self.lbs)), "shrinking not supported on sharded axis"
-    narg = tuple((s1//len(self.lbs), s2//len(self.lbs)) if a == self.axis else (s1,s2) for a,(s1,s2) in enumerate(arg))
-    return MultiLazyBuffer([x.shrink(narg) for x in self.lbs], self.axis)
-  def stride(self, arg:Tuple[int, ...]):
-    assert self.axis is None or arg[self.axis] == 1, "flipping not supported on sharded axis"
-    return MultiLazyBuffer([x.stride(arg) for x in self.lbs], self.axis)

tinygrad/features/search.py DELETED

@@ -1,160 +0,0 @@
-from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
-import itertools, random, math, time, multiprocessing, traceback, signal
-from tinygrad.device import Device, Compiled, Buffer, CompiledASTRunner
-from tinygrad.ops import MemBuffer, LazyOp
-from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
-from tinygrad.dtype import ImageDType
-from tinygrad.codegen.linearizer import Linearizer
-from collections import defaultdict
-from tinygrad.tensor import Tensor
-from tinygrad.shape.symbolic import sym_infer
-
-from tinygrad.codegen.kernel import Opt, OptOps
-actions = [Opt(op=OptOps.UPCAST, axis=axis, amt=amt) for amt in [0,2,3,4,7] for axis in range(6)]
-actions += [Opt(op=OptOps.UNROLL, axis=axis, amt=amt) for amt in [0,4] for axis in range(4)]
-actions += [Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,13,16,29] for axis in range(5)]
-actions += [Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,29,32,256] for axis in range(3)]
-actions += [Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32] for axis in range(7)]
-actions += [Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.UPCASTMID, axis=1, amt=4),
-            Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.GROUP, axis=1, amt=8),]
-if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]
-
-def _get_test_global_size(global_size, max_global_size, var_vals):
-  test_global_size, factor = [sym_infer(sz, var_vals) for sz in global_size], 1
-  while prod(test_global_size) > max_global_size:
-    for j in range(len(global_size)-1,-1,-1):
-      if test_global_size[j] > 16:
-        test_global_size[j] //= 2
-        factor *= 2
-        break
-  return test_global_size, factor
-
-def _time_program(ast:LazyOp, rdev:Compiled, lib:bytes, global_size, local_size, var_vals, rawbufs, early_stop=None, max_global_size=65536, clear_l2=False, cnt=3, name="test"): # noqa: E501
-  factor = 1
-  if global_size is not None and max_global_size is not None:
-    global_size, factor = _get_test_global_size(global_size, max_global_size, var_vals)
-  car = CompiledASTRunner(ast, name, "", lib, global_size, local_size).build(rdev.runtime)
-  tms = []
-  for _ in range(cnt):
-    if clear_l2:
-      with Context(DEBUG=0): Tensor.rand(1024,1024).realize()
-    tms.append(car(rawbufs, var_vals, wait=True, do_update_stats=False)*factor)
-    if early_stop is not None and early_stop < tms[-1]: break
-  return tms
-
-def _compile_linearizer(rdev:Compiled, lin:Linearizer, name:Optional[str]=None) -> Tuple[bytes, Optional[List[int]], Optional[List[int]]]:
-  lin.linearize()
-  src = rdev.renderer(name if name is not None else to_function_name(lin.name), lin.uops) # NOTE: these all have the same name for deduping
-  return rdev.compiler(src), lin.global_size, lin.local_size
-
-def _try_compile_linearized_w_idx(x):
-  try: return (x[0], _compile_linearizer(cast(Compiled, Device[Device.DEFAULT]), x[1], "test"))
-  except Exception:
-    if DEBUG >= 4: traceback.print_exc()
-    return (x[0], None)
-
-# workers should ignore ctrl c
-def _init_worker(): signal.signal(signal.SIGINT, signal.SIG_IGN)
-
-# *** external API ***
-
-# get (scrap) buffers for timing the linearizer
-def bufs_from_lin(lin:Linearizer) -> List[Buffer]:
-  bufsts:DefaultDict[int, List[MemBuffer]] = defaultdict(list)
-  for x in lin.membufs: bufsts[x.idx].append(x)
-  rawbufs:List[Optional[Buffer]] = [None]*len(bufsts)
-  for k,lx in bufsts.items():
-    buf_size = prod(lx[0].dtype.shape) if isinstance(lx[0].dtype, ImageDType) else max(y.st.real_size() for y in lx)
-    rawbufs[k] = Buffer(Device.DEFAULT, buf_size, lx[0].dtype)
-  assert all(r is not None for r in rawbufs)
-  return cast(List[Buffer], rawbufs)
-
-# get dictionary of all possible actions
-def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Linearizer]:
-  acted_lins = {0:lin} if include_0 else {}
-  for i,a in enumerate(actions):
-    if a.axis is not None and a.axis >= lin.shape_len: continue
-    if a.axis is not None and lin.full_shape[a.axis] == a.amt and Opt(a.op, a.axis, 0) in actions: continue
-    lin2 = lin.copy()
-    try:
-      lin2.apply_opt(a)
-      up, lcl = 1, 1
-      for s,c in zip(lin2.full_shape, lin2.colors()):
-        if c in {"magenta", "yellow"}: up *= s
-        if c in {"cyan", "green", "white"}: lcl *= s
-      if up > 256 or lcl > 256: continue
-      acted_lins[i+1] = lin2
-    except Exception:
-      pass
-  return acted_lins
-
-def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linearizer:
-  key = {"ast": str(lin.ast), "amt": amt, "allow_test_size": allow_test_size, "device": Device.DEFAULT}
-  if (val:=diskcache_get("beam_search", key)) is not None and not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1:
-    ret = lin.copy()
-    for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
-    return ret
-
-  beam: List[Tuple[Linearizer, float]] = []
-  seen_libs = set()
-
-  default_parallel = 1 if Device.DEFAULT in {"CUDA", "HIP"} else 0
-  pool = multiprocessing.Pool(multiprocessing.cpu_count(), _init_worker) if getenv("PARALLEL", default_parallel) else None
-
-  try:
-    var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()}
-    exiting, st = False, time.perf_counter()
-    dev = Device[Device.DEFAULT]
-    assert isinstance(dev, Compiled)
-    while not exiting:
-      acted_lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam]) if len(beam) else [lin]
-      timed_lins: List[Tuple[Linearizer, float]] = []
-      for i,proc in (pool.imap_unordered(_try_compile_linearized_w_idx, enumerate(acted_lins)) if pool is not None else map(_try_compile_linearized_w_idx, enumerate(acted_lins))): # noqa: E501
-        if proc is None: continue
-        lib, global_size, local_size = proc
-        if lib in seen_libs: continue
-        seen_libs.add(lib)
-        tms = _time_program(lin.ast, dev, lib, global_size, local_size, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0)
-        timed_lins.append((acted_lins[i], min(tms)))
-        if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501
-
-      # done
-      opts = sorted(timed_lins, key=lambda x: x[1])
-      exiting = len(opts) == 0 or (len(beam) > 0 and beam[0][1] <= opts[0][1])
-      if not exiting: beam = opts[:amt]
-      assert len(beam) > 0, "no BEAM items succeeded?!?"
-      if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(f"{beam[0][1]*1e6:12.2f} us", "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape()) # noqa: E501
-    if pool is not None: pool.close() # the pool is closed
-  except KeyboardInterrupt as e:
-    if pool is not None: pool.terminate()
-    raise e
-
-  if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
-  if DEBUG >= 3: print(beam[0][0].applied_opts)
-  return beam[0][0]
-
-def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[Buffer]) -> List[int]:
-  test_rawbuffers = [Buffer(rawbufs[0].device, rawbufs[0].size, rawbufs[0].dtype), *rawbufs[1:]] if rawbufs[0] in rawbufs[1:] else rawbufs
-  MAX_WORKGROUP = 1024
-  local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in global_size]
-  local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice
-  def try_exec(local_size):
-    try: return clprg(*[x._buf for x in test_rawbuffers], global_size=[g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)], local_size=local_size, wait=True) # noqa: E501
-    except Exception: return float('inf')
-  ret = min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])
-  assert not math.isinf(ret[0]), "all optimize_local_size exec failed"
-  return ret[1]
-
-def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501
-  key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size, "clear_l2": clear_l2, "device": Device.DEFAULT} # noqa: E501
-  if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
-
-  dev = Device[Device.DEFAULT]
-  assert isinstance(dev, Compiled)
-
-  var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()}
-  lib, global_size, local_size = _compile_linearizer(dev, lin)
-  tms = _time_program(lin.ast, dev, lib, global_size, local_size, var_vals, rawbufs, max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name)) # noqa: E501
-
-  if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms)
-  return min(tms)

tinygrad/graph.py DELETED

@@ -1,106 +0,0 @@
-import os, atexit
-from collections import defaultdict
-from typing import List, Any, DefaultDict
-from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, BufferOps, TernaryOps, Op, LazyOp, GlobalCounters
-from tinygrad.device import Device
-from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, getenv
-from tinygrad.codegen.linearizer import UOps, UOp
-from tinygrad.shape.symbolic import NumNode
-
-# **** debugging and graphing ****
-
-if DEBUG >= 2:
-  def print_globalcounters():
-    if GlobalCounters.time_sum_s == 0: return
-    print(f"avg: {GlobalCounters.global_ops*1e-9/GlobalCounters.time_sum_s:8.2f} GFLOPS {GlobalCounters.global_mem*1e-9/GlobalCounters.time_sum_s:8.2f} GB/s", # noqa: E501
-          f"{' '*10}total: {GlobalCounters.kernel_count:5d} kernels {GlobalCounters.global_ops*1e-9:8.2f} GOPS {GlobalCounters.global_mem*1e-9:8.2f} GB {GlobalCounters.time_sum_s*1e3:8.2f} ms") # noqa: E501
-  atexit.register(print_globalcounters)
-
-G:Any = None
-def init_graph():
-  global G
-  if G is not None: return
-  import networkx as nx
-  G = nx.DiGraph()
-  def save_graph_exit():
-    print("saving", G, f"to {GRAPHPATH}.svg")
-    nx.drawing.nx_pydot.write_dot(G, f'{GRAPHPATH}.dot')
-    # -Gnslimit=100 can make it finish, but you won't like results
-    os.system(f'dot -Tsvg {GRAPHPATH}.dot -o {GRAPHPATH}.svg')
-  atexit.register(save_graph_exit)
-
-counts: DefaultDict[type, int] = defaultdict(int)
-def nm(x):
-  if not hasattr(x, 'node_id'):
-    setattr(x, 'node_id', counts[type(x)])
-    counts[type(x)] += 1
-  return x.node_id
-
-def get_sop(op: List[Op]):
-  op = [x for x in op if x not in BufferOps]
-  if len(op) <= 2: return '.'.join([str(y).split(".")[1] for y in op][::-1])
-  if len(op) <= 6: return '.'.join([str(y).split(".")[1][0:3] for y in op][::-1])
-  return str(len(op))
-
-def str_dtype(dtyp):
-  ret = str(dtyp)[7:]
-  return "" if ret == 'float' else f"\n{ret}"
-
-def realized_lazybuffer(lb, num):
-  if GRAPH:
-    init_graph()
-    G.nodes[nm(lb)]['style'] = '"filled,bold"'
-    G.nodes[nm(lb)]['fillcolor'] = G.nodes[nm(lb)]['fillcolor'][:-2]
-    G.nodes[nm(lb)]['label'] = '"' + G.nodes[nm(lb)]["label"].replace('"', '') + f'\nK:{num} b:{nm(lb.realized)}"'
-
-def log_lazybuffer(lb, scheduled=False):
-  top_colors = {LoadOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#FFA0A0", BinaryOps: "#c0c0c0",
-                MovementOps: "#80ff80", TernaryOps: "#c0c0c0", BufferOps: '#a0a0ff'}
-  if GRAPH:
-    init_graph()
-    if lb.base != lb:
-      offset = lb.st.expr_node(NumNode(0))[0]
-      label = f"{lb.st.shape}\n{lb.st.real_strides()}" + (f"\n{offset}" if offset != 0 else "")
-      G.add_node(nm(lb), style='"filled,dashed"', fillcolor="#80ff8080", color="black", label=label)
-      G.add_edge(nm(lb.base), nm(lb), color='#00000060')
-      lb = lb.base
-    if lb.realized is None:
-      for x in lb.srcs:
-        if nm(x) not in G.nodes: log_lazybuffer(x)
-        G.add_edge(nm(x), nm(lb), color='#a0a0a0')
-      label = '"' + \
-        (str(set(x.shape for x in lb.srcs))+"\n"+str(lb.shape) if lb.op in ReduceOps else str(lb.shape)) + \
-        str_dtype(lb.dtype)+f"\n{lb.op}"+(f"\n{lb.arg}" if lb.op in {LoadOps.CONST, UnaryOps.CAST} else "") + \
-        (f"\n{lb.device}" if lb.device != Device.DEFAULT else "") + '"'
-      G.add_node(nm(lb), style='"filled,dashed"', fillcolor=[v for k,v in top_colors.items() if lb.op in k][0] + "80", color="black", label=label)
-      if scheduled: G.nodes[nm(lb)]['shape'] = 'box'
-    else:
-      if nm(lb) not in G.nodes:
-        # realized but unseen?
-        G.add_node(nm(lb), label=f'"{str(lb.base.realized)[5:-1].replace(" ", chr(10))}\nb:{nm(lb.realized)}"', style='filled', fillcolor="#f0c08080")
-
-def _tree(lazydata, cycles, cnt, prefix=""):
-  cnt[0] += 1
-  if len(lazydata.src) == 0: return [f"━━ {prefix}{lazydata.op.name} {lazydata.arg if lazydata.arg else ''}"]
-  if (lid := id(lazydata)) in cycles and cycles[lid][1] > (tcnt := getenv("TREE_CYCLE_CNT", 5)) and tcnt >= 0:
-    return [f"━⬆︎ goto {cycles[id(lazydata)][0]}: {lazydata.op.name}"]
-  cycles[lid] = (cnt[0], 1 if lid not in cycles else cycles[lid][1]+1)
-  lines = [f"━┳ {prefix}{lazydata.op.name} {lazydata.arg if lazydata.arg else ''}"]
-  childs = [_tree(c, cycles, cnt) for c in lazydata.src[:]]
-  for c in childs[:-1]: lines += [f" ┣{c[0]}"] + [f" ┃{l}" for l in c[1:]]
-  return lines + [" ┗"+childs[-1][0]] + [" "+l for l in childs[-1][1:]]
-
-def print_tree(lazydata:LazyOp): print("\n".join([f"{str(i).rjust(3)} {s}" for i,s in enumerate(_tree(lazydata, {}, [-1]))]))
-
-def graph_uops(uops:List[UOp]):
-  import networkx as nx
-  colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.SPECIAL: "#c0c0ff", UOps.CONST: "#e0e0e0",
-            UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0",
-            UOps.LOOP: "#c8a0e0", UOps.PHI: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0"}
-  G = nx.DiGraph()
-  for u in uops:
-    if u.uop == UOps.END: continue
-    G.add_node(uops.index(u), label=f"{str(u.uop)[5:]}{(' '+str(u.arg)) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.uop, "#ffffff")) # noqa: E501
-    for v in u.vin: G.add_edge(uops.index(v), uops.index(u))
-  nx.drawing.nx_pydot.write_dot(G, f'{GRAPHPATH}.uops.dot')
-  os.system(f'dot -Grankdir=LR -Tsvg {GRAPHPATH}.uops.dot -o {GRAPHPATH}.uops.svg')