tinygrad-0.9.0-py3-none-any.whl → tinygrad-0.9.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/__init__.py +0 -0
- tinygrad/codegen/kernel.py +78 -90
- tinygrad/codegen/linearizer.py +237 -169
- tinygrad/codegen/uops.py +278 -242
- tinygrad/device.py +147 -10
- tinygrad/dtype.py +7 -7
- tinygrad/engine/graph.py +16 -16
- tinygrad/engine/jit.py +39 -36
- tinygrad/engine/realize.py +6 -5
- tinygrad/engine/schedule.py +15 -7
- tinygrad/engine/search.py +6 -3
- tinygrad/function.py +17 -23
- tinygrad/helpers.py +77 -8
- tinygrad/lazy.py +26 -26
- tinygrad/multi.py +13 -9
- tinygrad/nn/__init__.py +1 -1
- tinygrad/nn/datasets.py +2 -1
- tinygrad/nn/state.py +3 -4
- tinygrad/ops.py +49 -16
- tinygrad/renderer/__init__.py +8 -4
- tinygrad/renderer/assembly.py +93 -100
- tinygrad/renderer/cstyle.py +47 -42
- tinygrad/renderer/llvmir.py +30 -30
- tinygrad/runtime/__init__.py +0 -0
- tinygrad/runtime/autogen/amd_gpu.py +11504 -1
- tinygrad/runtime/autogen/comgr.py +36 -10
- tinygrad/runtime/autogen/hsa.py +146 -14
- tinygrad/runtime/autogen/io_uring.py +1486 -0
- tinygrad/runtime/autogen/nv_gpu.py +269 -0
- tinygrad/runtime/driver/__init__.py +0 -0
- tinygrad/runtime/driver/hip_comgr.py +20 -11
- tinygrad/runtime/graph/__init__.py +0 -0
- tinygrad/runtime/graph/clang.py +3 -2
- tinygrad/runtime/graph/cuda.py +2 -2
- tinygrad/runtime/graph/hcq.py +122 -78
- tinygrad/runtime/ops_amd.py +302 -316
- tinygrad/runtime/ops_cuda.py +3 -3
- tinygrad/runtime/ops_disk.py +70 -5
- tinygrad/runtime/ops_gpu.py +2 -2
- tinygrad/runtime/ops_metal.py +5 -6
- tinygrad/runtime/ops_npy.py +1 -1
- tinygrad/runtime/ops_nv.py +161 -166
- tinygrad/runtime/ops_python.py +20 -16
- tinygrad/shape/__init__.py +0 -0
- tinygrad/shape/shapetracker.py +5 -2
- tinygrad/shape/symbolic.py +1 -3
- tinygrad/shape/view.py +34 -19
- tinygrad/tensor.py +219 -135
- {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +14 -6
- tinygrad-0.9.1.dist-info/RECORD +63 -0
- {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
- tinygrad/runtime/driver/hsa.py +0 -143
- tinygrad/runtime/graph/hsa.py +0 -171
- tinygrad/runtime/ops_hsa.py +0 -278
- tinygrad-0.9.0.dist-info/RECORD +0 -60
- {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
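A listing like the one above can be reproduced locally by comparing the RECORD files of the two wheels. A minimal sketch (not part of any diff tooling; the wheel paths are assumptions, fetch them first with `pip download tinygrad==0.9.0` and `pip download tinygrad==0.9.1`):

```python
import zipfile

def record_paths(whl: str) -> set:
    # The RECORD file lists every path in the wheel as "path,hash,size".
    with zipfile.ZipFile(whl) as z:
        record = next(n for n in z.namelist() if n.endswith(".dist-info/RECORD"))
        return {line.split(",")[0] for line in z.read(record).decode().splitlines() if line}

old = record_paths("tinygrad-0.9.0-py3-none-any.whl")
new = record_paths("tinygrad-0.9.1-py3-none-any.whl")
print("added:  ", sorted(new - old))    # e.g. tinygrad/runtime/autogen/io_uring.py
print("removed:", sorted(old - new))    # e.g. tinygrad/runtime/ops_hsa.py
```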
{tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: tinygrad
-Version: 0.9.0
+Version: 0.9.1
 Summary: You like pytorch? You like micrograd? You love tinygrad! <3
 Author: George Hotz
 License: MIT
@@ -10,7 +10,6 @@ Requires-Python: >=3.8
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: numpy
-Requires-Dist: tqdm
 Requires-Dist: pyobjc-framework-Metal ; platform_system == "Darwin"
 Requires-Dist: pyobjc-framework-libdispatch ; platform_system == "Darwin"
 Provides-Extra: arm
@@ -39,6 +38,7 @@ Requires-Dist: onnx ==1.16.0 ; extra == 'testing'
 Requires-Dist: onnx2torch ; extra == 'testing'
 Requires-Dist: opencv-python ; extra == 'testing'
 Requires-Dist: tabulate ; extra == 'testing'
+Requires-Dist: tqdm ; extra == 'testing'
 Requires-Dist: safetensors ; extra == 'testing'
 Requires-Dist: transformers ; extra == 'testing'
 Requires-Dist: sentencepiece ; extra == 'testing'
@@ -47,6 +47,7 @@ Requires-Dist: librosa ; extra == 'testing'
 Requires-Dist: networkx ; extra == 'testing'
 Requires-Dist: hypothesis ; extra == 'testing'
 Requires-Dist: nibabel ; extra == 'testing'
+Requires-Dist: bottle ; extra == 'testing'
 Provides-Extra: testing_tf
 Requires-Dist: tensorflow ==2.15.1 ; extra == 'testing_tf'
 Requires-Dist: tensorflow-addons ; extra == 'testing_tf'
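The dependency hunks above move tqdm from the core requirements into the 'testing' extra and add bottle as a testing dependency, so a bare `pip install tinygrad` no longer pulls in tqdm. A minimal sketch to verify this against an installed 0.9.1, using only the standard library (exact marker formatting may vary by packaging backend):

```python
from importlib.metadata import requires

# Print tinygrad's declared requirements; in 0.9.1, tqdm should carry an
# extra == "testing" marker while numpy remains unconditional.
for req in requires("tinygrad") or []:
    if "tqdm" in req or req.startswith("numpy"):
        print(req)
```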
@@ -64,7 +65,7 @@ tinygrad: For something between [PyTorch](https://github.com/pytorch/pytorch) an

 <h3>

-[Homepage](https://github.com/tinygrad/tinygrad) | [Documentation](
+[Homepage](https://github.com/tinygrad/tinygrad) | [Documentation](https://docs.tinygrad.org/) | [Discord](https://discord.gg/ZjZadyC7PK)

 </h3>

@@ -139,7 +140,8 @@ tinygrad already supports numerous accelerators, including:
 - [x] [LLVM](tinygrad/runtime/ops_llvm.py)
 - [x] [METAL](tinygrad/runtime/ops_metal.py)
 - [x] [CUDA](tinygrad/runtime/ops_cuda.py)
-- [x] [
+- [x] [AMD](tinygrad/runtime/ops_amd.py)
+- [x] [NV](tinygrad/runtime/ops_nv.py)

 And it is easy to add more! Your accelerator of choice only needs to support a total of ~25 low level ops.

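The README hunk above reflects the backend change in this release: the HSA backend is removed (see the deleted files at the bottom of this diff) and the new AMD and NV backends replace it. A minimal sketch of selecting one of them (assumes tinygrad 0.9.1 and supported hardware; as with the other tinygrad backends, selection is driven by an environment variable set before use):

```python
import os
os.environ["AMD"] = "1"  # or os.environ["NV"] = "1" for the NV backend

from tinygrad import Tensor, Device

print(Device.DEFAULT)                 # expected to report AMD here
print((Tensor.ones(3) + 1).tolist())  # [2.0, 2.0, 2.0], run on that backend
```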
@@ -163,7 +165,7 @@ python3 -m pip install git+https://github.com/tinygrad/tinygrad.git

 ## Documentation

-Documentation along with a quick start guide can be found
+Documentation along with a quick start guide can be found on the [docs website](https://docs.tinygrad.org/) built from the [docs/](/docs) directory.

 ### Quick example comparing to PyTorch

@@ -209,7 +211,7 @@ Now, what we want:
 - Bug fixes (with a regression test) are great! This library isn't 1.0 yet, so if you stumble upon a bug, fix it, write a test, and submit a PR, this is valuable work.
 - Solving bounties! tinygrad [offers cash bounties](https://docs.google.com/spreadsheets/d/1WKHbT-7KOgjEawq5h5Ic1qUWzpfAzuD_J06N1JwOCGs/edit?usp=sharing) for certain improvements to the library. All new code should be high quality and well tested.
 - Features. However, if you are adding a feature, consider the line tradeoff. If it's 3 lines, there's less of a bar of usefulness it has to meet over something that's 30 or 300 lines. All features must have regression tests. In general with no other constraints, your feature's API should match torch or numpy.
-- Refactors that are clear wins. In general, if your refactor isn't a clear win it will be closed. But some refactors are amazing! Think about readability in a deep core sense. A whitespace change or moving a few functions around is useless, but if you realize that two 100 line functions can actually use the same 110 line function with arguments while also improving readability, this is a big win.
+- Refactors that are clear wins. In general, if your refactor isn't a clear win it will be closed. But some refactors are amazing! Think about readability in a deep core sense. A whitespace change or moving a few functions around is useless, but if you realize that two 100 line functions can actually use the same 110 line function with arguments while also improving readability, this is a big win. Refactors should pass [process replay](#process-replay-tests).
 - Tests/fuzzers. If you can add tests that are non brittle, they are welcome. We have some fuzzers in here too, and there's a plethora of bugs that can be found with them and by improving them. Finding bugs, even writing broken tests (that should pass) with `@unittest.expectedFailure` is great. This is how we make progress.
 - Dead code removal from core `tinygrad/` folder. We don't care about the code in extra, but removing dead code from the core library is great. Less for new people to read and be confused by.
@@ -225,3 +227,9 @@ python3 -m pip install -e '.[testing]' # install extra deps for testing
 python3 test/test_ops.py # just the ops tests
 python3 -m pytest test/ # whole test suite
 ```
+
+#### Process replay tests
+
+[Process replay](https://github.com/tinygrad/tinygrad/blob/master/test/external/process_replay/process_replay.py) detects changes in the generated kernels of CI tests by comparing them against tinygrad master. If your PR is a refactor or speedup without any expected behavior change, it should include a green process replay pass to get merged.
+
+You can enable process replay by adding [run_process_replay] to your PR title. [example](https://github.com/tinygrad/tinygrad/pull/4995). Note that you should keep your branch up-to-date with master.
tinygrad-0.9.1.dist-info/RECORD

@@ -0,0 +1,63 @@
+tinygrad/__init__.py,sha256=jC-35zswLSXLuRRThG_o6yar6qQjLCqmeaFCj_XKN08,449
+tinygrad/device.py,sha256=oy8vQAn0bLSVnLtu6MfusWF6Axv83xEwrJB3C1sWcPo,18287
+tinygrad/dtype.py,sha256=LGHdreCNcM-FJTwT3uNrsGIMCSnBHi4ZeGf3K0QHXGU,6199
+tinygrad/function.py,sha256=1PS3nZBicJmMAZ-IvE0SJAHWrGyRMkWM9-J_Hr5az0E,9473
+tinygrad/helpers.py,sha256=SQfdz3AxiAT_k7GExSIXKBYqvCugxaF7ftQ9gr8qvFg,16814
+tinygrad/lazy.py,sha256=mrmxsa6dAQhVAm2UMs7yM18Y-IeJyG4rqyW4-LL5apQ,13336
+tinygrad/multi.py,sha256=EK9OTqmt2EaEgJIIebZUe944NaoJLTFbnUGxVYIk--k,11520
+tinygrad/ops.py,sha256=ZOKpgKY_y3BGqu5NpL172me10Lw4SYVkkCUeD3HR7MI,8607
+tinygrad/tensor.py,sha256=-TRbZ1ydsxgsHUlMDfx6AuEY4qtxUbYhYQ8hK7w8mt0,133764
+tinygrad/codegen/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tinygrad/codegen/kernel.py,sha256=Bw53Hb0vfdyyz2JQhE7YV7DqfZh3gVsxtLFABkx8Xps,36961
+tinygrad/codegen/linearizer.py,sha256=oxqSuzBwpbOtZL1vTxVR4qQqjdvvAUwayhVMltkkX1s,32964
+tinygrad/codegen/uops.py,sha256=GzEwfeyATi9uydzHnyFeN_V47hrk7pkz3gHePjyp6jY,25454
+tinygrad/engine/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tinygrad/engine/graph.py,sha256=2w16XKpAXSkEML8Hl3m1bbPnLDV-MBsaBJLBPx2QzA0,5264
+tinygrad/engine/jit.py,sha256=xoiApaot299mcjKPbxwUlqu2SJ3Tbyfs_rm5zKuHNMA,11355
+tinygrad/engine/realize.py,sha256=LsfZIHQ2Zz_79a4LKBU1pMUwbleM1x_oBTyn4D1qRc8,11214
+tinygrad/engine/schedule.py,sha256=ltOshvCgUU5wHqTDVLhxPVo7DGq7VRA6ECRX6C_09oU,18984
+tinygrad/engine/search.py,sha256=rLS1Aj_r42L0_uUJRyytqFTpJvXxPdx0LRD7BzvF2dQ,11359
+tinygrad/nn/__init__.py,sha256=-CH1lNAoxJQjFeGFpp78RYLaD2sD4Q4biJA5uvWwm88,12942
+tinygrad/nn/datasets.py,sha256=3tau-L-HiopP1leFDXZV-VamCR0DxOhAu6wVOkAVSZY,493
+tinygrad/nn/optim.py,sha256=zf85kwumpS17fk1NBjUfe7tOUE7-XH37SL-LjgAfva8,6799
+tinygrad/nn/state.py,sha256=cYxsy6IiMCl1wITZcM04GirsftFOutp-2X0PIow35aY,10111
+tinygrad/renderer/__init__.py,sha256=OVEBlNy-eaREIn9jrY0tBhVsmMYi1eZK3wW_P_f26-M,2935
+tinygrad/renderer/assembly.py,sha256=syUM3dMOrmIEkmVWzOS47E2sZXDUWbVXmkofl0RnnWw,18011
+tinygrad/renderer/cstyle.py,sha256=O2MNxsxZzMCPY7JAC_5w5aUqosJSi4LGdH8PZoA7s2A,24387
+tinygrad/renderer/llvmir.py,sha256=43iw-cfwukwffibuZQOm3jI3fPPcRT8rs4-JVglVIHU,10192
+tinygrad/runtime/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tinygrad/runtime/ops_amd.py,sha256=h4ZwP64C82aNuHTUtjRykiEzsdhUwGKtvfNdsXOUXyI,31393
+tinygrad/runtime/ops_clang.py,sha256=XWqwobReRdu-Tj-chbWEJFMx6AQfgdGCcpdWcLWUTOQ,1468
+tinygrad/runtime/ops_cuda.py,sha256=LujFWnZvLEaPlebomy9c3C4WAbbaEswdeuJgCNDIiGg,10860
+tinygrad/runtime/ops_disk.py,sha256=5AGIXyGb_x3zuPQ54M6p0xkROjtQo0MMeiqSUPnQ6MQ,6655
+tinygrad/runtime/ops_gpu.py,sha256=z9MPEQawWbJGfLyN3KOgxXaENEbi3ULm8xMWcrln1c0,7532
+tinygrad/runtime/ops_llvm.py,sha256=dODiyVSlPofHyDIZrD-V74h8W1d94VPnr_-A4gNbSO4,2229
+tinygrad/runtime/ops_metal.py,sha256=cb4xIXp_4wyMjk7E2XuSwcCzh1JEwc21p6hJxB8LPgc,5802
+tinygrad/runtime/ops_npy.py,sha256=HSDb47vcbPvDiG3EmSGPC0W-EPO9O5R0TqlQKnYxANU,404
+tinygrad/runtime/ops_nv.py,sha256=ClqazkWrMOheiyRDl1ZY7azMPqFHstnEhlAEKNSkklg,37657
+tinygrad/runtime/ops_python.py,sha256=2B1ty1fr4SpPq-bfEyA-gDxxvjI7ulS1EAqHwHFn-QE,11051
+tinygrad/runtime/autogen/amd_gpu.py,sha256=kqu1ygAUxymdwFznsIrXnsJYDa0UDpVndTlTuj9XK_4,581516
+tinygrad/runtime/autogen/comgr.py,sha256=mhQtuF_vGDrJZD3iyQ_38FrS_2jp3WtImNZPC4TBuAg,39793
+tinygrad/runtime/autogen/cuda.py,sha256=GgRl4AfU54JG0G1XJj2dq2FbrUZ8XG_AnFrPAZJpSSg,250380
+tinygrad/runtime/autogen/hip.py,sha256=1yUHDCwL3KkD15if2Q1Ud3GbJiR7DxsNorKZTCINw54,245532
+tinygrad/runtime/autogen/hsa.py,sha256=7Hsrn17HmChyeFOSX_3Fnzl9c0COtq2Z2ExqGu5FNiU,277716
+tinygrad/runtime/autogen/io_uring.py,sha256=cWYzokworiNV1x4I6eBmZpqV2nCt5B_rVEDx7pFP2W0,57664
+tinygrad/runtime/autogen/kfd.py,sha256=dDmLFL6HL_QXRW2rZOCArY55PRuXuLN9563XCipV2jM,29935
+tinygrad/runtime/autogen/nv_gpu.py,sha256=Hd8p-ay8dbdqFwvGyEtXLX3Ymjr3DP1XSRMAe0zKrq4,1686859
+tinygrad/runtime/autogen/opencl.py,sha256=aW-luGFF5PXFEZ6SgrGANhA9qpkI-fZyEsmDfpez2Ss,82654
+tinygrad/runtime/driver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tinygrad/runtime/driver/hip_comgr.py,sha256=rCS6DaUjkYflI1OvP5XGCOvcy67XCAgeSvQMUb1K9iY,3840
+tinygrad/runtime/graph/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tinygrad/runtime/graph/clang.py,sha256=NjobP8AVzNC-CS25FjO5gl2-B8KGJb8K8kLzNXEycPg,2005
+tinygrad/runtime/graph/cuda.py,sha256=ZXOhUhTzz3L2ut-E0tCTrqTEcDFExx_FaZzBm8HcQIk,5268
+tinygrad/runtime/graph/hcq.py,sha256=iJcAw-95PXwjkNgQbwwhFmq1m9Zprq8wwPNFkil_SEQ,11877
+tinygrad/runtime/graph/metal.py,sha256=bwB6uAsqjEbwv5ML5ziWduBtmTpseJboo6J9ssVa4v4,4579
+tinygrad/shape/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tinygrad/shape/shapetracker.py,sha256=oJqr3q9HoDwNN8S2VV9H_IpmG_ovGY10CTSaWWZjUz4,6387
+tinygrad/shape/symbolic.py,sha256=ORyH9bkb94u3ak0XvM7TvZj2n-_RlR9FcYlQqmIUc2M,16617
+tinygrad/shape/view.py,sha256=H9tdtNWD1xKci_Qhf3EJsWKPD8x4M-e7MAxgoUDqG9k,17999
+tinygrad-0.9.1.dist-info/LICENSE,sha256=ABRhUPEILzINYIukgazD-_rPipkUNUwslrb0RxnV6Xc,1058
+tinygrad-0.9.1.dist-info/METADATA,sha256=YApgO_KsHfOTp9m10ytDj_VIVghIq5GP9PTIiIhre2c,10990
+tinygrad-0.9.1.dist-info/WHEEL,sha256=mguMlWGMX-VHnMpKOjjQidIo1ssRlCFu4a4mBpz1s2M,91
+tinygrad-0.9.1.dist-info/top_level.txt,sha256=vDABMCWBFQnx2kn9Azueu88FP-1klQdePoHikQhHymc,9
+tinygrad-0.9.1.dist-info/RECORD,,
tinygrad/runtime/driver/hsa.py DELETED

@@ -1,143 +0,0 @@
-import ctypes, collections
-import tinygrad.runtime.autogen.hsa as hsa
-from tinygrad.helpers import init_c_var
-
-def check(status):
-  if status != 0:
-    hsa.hsa_status_string(status, ctypes.byref(status_str := ctypes.POINTER(ctypes.c_char)()))
-    raise RuntimeError(f"HSA Error {status}: {ctypes.string_at(status_str).decode()}")
-
-# Precalulated AQL info
-AQL_PACKET_SIZE = ctypes.sizeof(hsa.hsa_kernel_dispatch_packet_t)
-EMPTY_SIGNAL = hsa.hsa_signal_t()
-
-DISPATCH_KERNEL_SETUP = 3 << hsa.HSA_KERNEL_DISPATCH_PACKET_SETUP_DIMENSIONS
-DISPATCH_KERNEL_HEADER = 1 << hsa.HSA_PACKET_HEADER_BARRIER
-DISPATCH_KERNEL_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE
-DISPATCH_KERNEL_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE
-DISPATCH_KERNEL_HEADER |= hsa.HSA_PACKET_TYPE_KERNEL_DISPATCH << hsa.HSA_PACKET_HEADER_TYPE
-
-BARRIER_HEADER = 1 << hsa.HSA_PACKET_HEADER_BARRIER
-BARRIER_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE
-BARRIER_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE
-BARRIER_HEADER |= hsa.HSA_PACKET_TYPE_BARRIER_AND << hsa.HSA_PACKET_HEADER_TYPE
-
-class AQLQueue:
-  def __init__(self, device, sz=-1):
-    self.device = device
-
-    check(hsa.hsa_agent_get_info(self.device.agent, hsa.HSA_AGENT_INFO_QUEUE_MAX_SIZE, ctypes.byref(max_queue_size := ctypes.c_uint32())))
-    queue_size = min(max_queue_size.value, sz) if sz != -1 else max_queue_size.value
-
-    null_func = ctypes.CFUNCTYPE(None, hsa.hsa_status_t, ctypes.POINTER(hsa.struct_hsa_queue_s), ctypes.c_void_p)()
-    self.hw_queue = init_c_var(ctypes.POINTER(hsa.hsa_queue_t)(), lambda x: check(
-      hsa.hsa_queue_create(self.device.agent, queue_size, hsa.HSA_QUEUE_TYPE_SINGLE, null_func, None, (1<<32)-1, (1<<32)-1, ctypes.byref(x))))
-
-    self.next_doorbell_index = 0
-    self.queue_base = self.hw_queue.contents.base_address
-    self.queue_size = self.hw_queue.contents.size * AQL_PACKET_SIZE # in bytes
-    self.write_addr = self.queue_base
-    self.write_addr_end = self.queue_base + self.queue_size - 1 # precalc saves some time
-    self.available_packet_slots = self.hw_queue.contents.size
-
-    check(hsa.hsa_amd_queue_set_priority(self.hw_queue, hsa.HSA_AMD_QUEUE_PRIORITY_HIGH))
-    check(hsa.hsa_amd_profiling_set_profiler_enabled(self.hw_queue, 1))
-
-  def __del__(self):
-    if hasattr(self, 'hw_queue'): check(hsa.hsa_queue_destroy(self.hw_queue))
-
-  def submit_kernel(self, prg, global_size, local_size, kernargs, completion_signal=None):
-    if self.available_packet_slots == 0: self._wait_queue()
-
-    packet = hsa.hsa_kernel_dispatch_packet_t.from_address(self.write_addr)
-    packet.workgroup_size_x = local_size[0]
-    packet.workgroup_size_y = local_size[1]
-    packet.workgroup_size_z = local_size[2]
-    packet.reserved0 = 0
-    packet.grid_size_x = global_size[0] * local_size[0]
-    packet.grid_size_y = global_size[1] * local_size[1]
-    packet.grid_size_z = global_size[2] * local_size[2]
-    packet.private_segment_size = prg.private_segment_size
-    packet.group_segment_size = prg.group_segment_size
-    packet.kernel_object = prg.handle
-    packet.kernarg_address = kernargs
-    packet.reserved2 = 0
-    packet.completion_signal = completion_signal if completion_signal else EMPTY_SIGNAL
-    packet.setup = DISPATCH_KERNEL_SETUP
-    packet.header = DISPATCH_KERNEL_HEADER
-    self._submit_packet()
-
-  def submit_barrier(self, wait_signals=None, completion_signal=None):
-    assert wait_signals is None or len(wait_signals) <= 5
-    if self.available_packet_slots == 0: self._wait_queue()
-
-    packet = hsa.hsa_barrier_and_packet_t.from_address(self.write_addr)
-    packet.reserved0 = 0
-    packet.reserved1 = 0
-    for i in range(5):
-      packet.dep_signal[i] = wait_signals[i] if wait_signals and len(wait_signals) > i else EMPTY_SIGNAL
-    packet.reserved2 = 0
-    packet.completion_signal = completion_signal if completion_signal else EMPTY_SIGNAL
-    packet.header = BARRIER_HEADER
-    self._submit_packet()
-
-  def blit_packets(self, packet_addr, packet_cnt):
-    if self.available_packet_slots < packet_cnt: self._wait_queue(packet_cnt)
-
-    tail_blit_packets = min((self.queue_base + self.queue_size - self.write_addr) // AQL_PACKET_SIZE, packet_cnt)
-    rem_packet_cnt = packet_cnt - tail_blit_packets
-    ctypes.memmove(self.write_addr, packet_addr, AQL_PACKET_SIZE * tail_blit_packets)
-    if rem_packet_cnt > 0: ctypes.memmove(self.queue_base, packet_addr + AQL_PACKET_SIZE * tail_blit_packets, AQL_PACKET_SIZE * rem_packet_cnt)
-
-    self._submit_packet(packet_cnt)
-
-  def wait(self):
-    self.submit_barrier([], finish_signal := self.device.alloc_signal(reusable=True))
-    hsa.hsa_signal_wait_scacquire(finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
-    self.available_packet_slots = self.queue_size // AQL_PACKET_SIZE
-
-  def _wait_queue(self, need_packets=1):
-    while self.available_packet_slots < need_packets:
-      rindex = hsa.hsa_queue_load_read_index_relaxed(self.hw_queue)
-      self.available_packet_slots = self.queue_size // AQL_PACKET_SIZE - (self.next_doorbell_index - rindex)
-
-  def _submit_packet(self, cnt=1):
-    self.available_packet_slots -= cnt
-    self.next_doorbell_index += cnt
-    hsa.hsa_queue_store_write_index_relaxed(self.hw_queue, self.next_doorbell_index)
-    hsa.hsa_signal_store_screlease(self.hw_queue.contents.doorbell_signal, self.next_doorbell_index-1)
-
-    self.write_addr += AQL_PACKET_SIZE * cnt
-    if self.write_addr > self.write_addr_end:
-      self.write_addr = self.queue_base + (self.write_addr - self.queue_base) % self.queue_size
-
-def scan_agents():
-  agents = collections.defaultdict(list)
-
-  @ctypes.CFUNCTYPE(hsa.hsa_status_t, hsa.hsa_agent_t, ctypes.c_void_p)
-  def __scan_agents(agent, data):
-    status = hsa.hsa_agent_get_info(agent, hsa.HSA_AGENT_INFO_DEVICE, ctypes.byref(device_type := hsa.hsa_device_type_t()))
-    if status == 0: agents[device_type.value].append(agent)
-    return hsa.HSA_STATUS_SUCCESS
-
-  hsa.hsa_iterate_agents(__scan_agents, None)
-  return agents
-
-def find_memory_pool(agent, segtyp=-1, location=-1):
-  @ctypes.CFUNCTYPE(hsa.hsa_status_t, hsa.hsa_amd_memory_pool_t, ctypes.c_void_p)
-  def __filter_amd_memory_pools(mem_pool, data):
-    check(hsa.hsa_amd_memory_pool_get_info(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_SEGMENT, ctypes.byref(segment := hsa.hsa_amd_segment_t())))
-    if segtyp >= 0 and segment.value != segtyp: return hsa.HSA_STATUS_SUCCESS
-
-    check(hsa.hsa_amd_memory_pool_get_info(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_LOCATION, ctypes.byref(loc:=hsa.hsa_amd_memory_pool_location_t())))
-    if location >= 0 and loc.value != location: return hsa.HSA_STATUS_SUCCESS
-
-    check(hsa.hsa_amd_memory_pool_get_info(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_SIZE, ctypes.byref(sz := ctypes.c_size_t())))
-    if sz.value == 0: return hsa.HSA_STATUS_SUCCESS
-
-    ret = ctypes.cast(data, ctypes.POINTER(hsa.hsa_amd_memory_pool_t))
-    ret[0] = mem_pool
-    return hsa.HSA_STATUS_INFO_BREAK
-
-  hsa.hsa_amd_agent_iterate_memory_pools(agent, __filter_amd_memory_pools, ctypes.byref(region := hsa.hsa_amd_memory_pool_t()))
-  return region
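The core of the deleted driver is the ring-buffer bookkeeping in `_submit_packet`/`_wait_queue`: AQL packets are fixed-size slots, and the write pointer advances by whole packets, wrapping modulo the queue size. An illustrative sketch of the same arithmetic in plain Python (not tinygrad API; `PACKET_SIZE` stands in for `ctypes.sizeof(hsa.hsa_kernel_dispatch_packet_t)`, which is 64 bytes):

```python
PACKET_SIZE = 64  # one AQL packet slot

class RingWriter:
    """Write pointer over a fixed ring of packet slots, as in AQLQueue."""
    def __init__(self, base: int, num_slots: int):
        self.base, self.size = base, num_slots * PACKET_SIZE
        self.write_addr = base

    def advance(self, cnt: int = 1) -> None:
        # Same arithmetic as AQLQueue._submit_packet: bump by whole packets,
        # then wrap modulo the queue size once past the end.
        self.write_addr += PACKET_SIZE * cnt
        if self.write_addr >= self.base + self.size:
            self.write_addr = self.base + (self.write_addr - self.base) % self.size

w = RingWriter(base=0x1000, num_slots=4)
for _ in range(6): w.advance()
print(hex(w.write_addr))  # 0x1080: six packets into a four-slot ring, wrapped
```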
tinygrad/runtime/graph/hsa.py DELETED

@@ -1,171 +0,0 @@
-import ctypes, collections, time, itertools
-from typing import List, Any, Dict, cast, Optional, Tuple
-from tinygrad.helpers import GraphException, init_c_var, round_up
-from tinygrad.device import Buffer, BufferOptions
-from tinygrad.device import Compiled, Device
-from tinygrad.shape.symbolic import Variable
-from tinygrad.runtime.ops_hsa import HSADevice, PROFILE, Profiler
-from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
-from tinygrad.engine.jit import MultiGraphRunner
-import tinygrad.runtime.autogen.hsa as hsa
-from tinygrad.runtime.driver.hsa import check, AQLQueue, AQL_PACKET_SIZE, EMPTY_SIGNAL
-
-def dedup_signals(signals): return [hsa.hsa_signal_t(hndl) for hndl in set([x.handle for x in signals if isinstance(x, hsa.hsa_signal_t)])]
-
-class VirtAQLQueue(AQLQueue):
-  def __init__(self, device, sz):
-    self.device = device
-    self.virt_queue = (hsa.hsa_kernel_dispatch_packet_t * sz)()
-    self.queue_base = self.write_addr = ctypes.addressof(self.virt_queue)
-    self.packets_count = 0
-    self.available_packet_slots = sz
-  def _wait_queue(self, need_packets=1): assert False, f"VirtQueue is too small to handle {self.packets_count+need_packets} packets!"
-  def _submit_packet(self):
-    self.write_addr += AQL_PACKET_SIZE
-    self.packets_count += 1
-    self.available_packet_slots -= 1
-
-class HSAGraph(MultiGraphRunner):
-  def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
-    super().__init__(jit_cache, input_rawbuffers, var_vals)
-
-    # Check all jit items are compatible.
-    compiled_devices = set()
-    for ji in self.jit_cache:
-      if isinstance(ji.prg, CompiledRunner): compiled_devices.add(ji.prg.device)
-      elif isinstance(ji.prg, BufferXfer):
-        for x in ji.bufs[0:2]: compiled_devices.add(Device[cast(Buffer, x).device])
-      else: raise GraphException
-    if any(not isinstance(d, HSADevice) for d in compiled_devices): raise GraphException
-
-    self.devices: List[HSADevice] = list(compiled_devices) #type:ignore
-
-    # Allocate kernel args.
-    kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
-    for ji in self.jit_cache:
-      if isinstance(ji.prg, CompiledRunner): kernargs_size[ji.prg.device] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
-    kernargs_ptrs: Dict[Compiled, int] = {dev:dev.allocator._alloc(sz, BufferOptions()) for dev,sz in kernargs_size.items()}
-
-    # Fill initial arguments.
-    self.ji_kargs_structs: Dict[int, ctypes.Structure] = {}
-    for j,ji in enumerate(self.jit_cache):
-      if not isinstance(ji.prg, CompiledRunner): continue
-      self.ji_kargs_structs[j] = ji.prg.clprg.args_struct_t.from_address(kernargs_ptrs[ji.prg.device])
-      kernargs_ptrs[ji.prg.device] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
-      for i in range(len(ji.bufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.bufs[i])._buf)
-      for i in range(len(ji.prg.p.vars)): self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[ji.prg.p.vars[i]])
-
-    # Build queues.
-    self.virt_aql_queues: Dict[Compiled, VirtAQLQueue] = {dev:VirtAQLQueue(dev, 2*len(self.jit_cache)+16) for dev in self.devices}
-    self.packets = {}
-    self.transfers = []
-    self.ji_to_transfer: Dict[int, int] = {} # faster to store transfers as list and update using this mapping table.
-    self.signals_to_reset: List[hsa.hsa_signal_t] = []
-    self.signals_to_devices: Dict[ctypes.c_uint64, List[HSADevice]] = {}
-    self.profile_info: Dict[Compiled, List[Tuple[Any, ...]]] = collections.defaultdict(list)
-
-    # Special packet to wait for the world.
-    self.kickoff_signals: Dict[HSADevice, hsa.hsa_signal_t] = {dev:self.alloc_signal(reset_on_start=True) for dev in self.devices}
-    for dev in self.devices: self.virt_aql_queues[dev].submit_barrier([], self.kickoff_signals[dev])
-
-    for j,ji in enumerate(self.jit_cache):
-      if isinstance(ji.prg, CompiledRunner):
-        wait_signals = self.access_resources(ji.bufs[(outs:=ji.prg.p.outcount):], ji.bufs[:outs], new_dependency=j, sync_with_aql_packets=False)
-        for i in range(0, len(wait_signals), 5):
-          self.virt_aql_queues[ji.prg.device].submit_barrier(wait_signals[i:i+5])
-        self.packets[j] = hsa.hsa_kernel_dispatch_packet_t.from_address(self.virt_aql_queues[ji.prg.device].write_addr)
-
-        sync_signal = self.alloc_signal(reset_on_start=True) if PROFILE else None
-        self.virt_aql_queues[ji.prg.device].submit_kernel(ji.prg.clprg, *ji.prg.p.launch_dims(var_vals), #type:ignore
-                                                          ctypes.addressof(self.ji_kargs_structs[j]), completion_signal=sync_signal)
-        if PROFILE: self.profile_info[ji.prg.device].append((sync_signal, ji.prg.clprg.name, False))
-      elif isinstance(ji.prg, BufferXfer):
-        dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
-        dest_dev, src_dev = cast(HSADevice, Device[dest.device]), cast(HSADevice, Device[src.device])
-        sync_signal = self.alloc_signal(reset_on_start=True, wait_on=[dest_dev, src_dev])
-
-        wait_signals = self.access_resources(read=[src], write=[dest], new_dependency=sync_signal, sync_with_aql_packets=True)
-        self.transfers.append([dest._buf, dest_dev.agent, src._buf, src_dev.agent, dest.nbytes, len(wait_signals),
-                               (hsa.hsa_signal_t*len(wait_signals))(*wait_signals), sync_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True])
-        self.ji_to_transfer[j] = len(self.transfers) - 1
-        if PROFILE: self.profile_info[src_dev].append((sync_signal, f"transfer: HSA:{src_dev.device_id} -> HSA:{dest_dev.device_id}", True))
-
-    # Wait for all active signals to finish the graph
-    wait_signals_to_finish: Dict[HSADevice, List[hsa.hsa_signal_t]] = collections.defaultdict(list)
-    for v in dedup_signals(list(self.w_dependency_map.values()) + list(itertools.chain.from_iterable(self.r_dependency_map.values()))):
-      for dev in self.signals_to_devices[v.handle]:
-        wait_signals_to_finish[dev].append(v)
-
-    self.finish_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x))))
-    for dev in self.devices:
-      wait_signals = wait_signals_to_finish[dev]
-      for i in range(0, max(1, len(wait_signals)), 5):
-        self.virt_aql_queues[dev].submit_barrier(wait_signals[i:i+5], completion_signal=self.finish_signal if i+5>=len(wait_signals) else None)
-
-    # Zero signals to allow graph to start and execute.
-    for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 0)
-    hsa.hsa_signal_silent_store_relaxed(self.finish_signal, 0)
-
-  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
-    # Wait and restore signals
-    hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
-    for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 1)
-    hsa.hsa_signal_silent_store_relaxed(self.finish_signal, len(self.devices))
-
-    # Update rawbuffers
-    for (j,i),input_idx in self.input_replace.items():
-      if j in self.ji_kargs_structs:
-        self.ji_kargs_structs[j].__setattr__(f'f{i}', input_rawbuffers[input_idx]._buf)
-      else:
-        if i == 0: self.transfers[self.ji_to_transfer[j]][0] = input_rawbuffers[input_idx]._buf # dest
-        elif i == 1: self.transfers[self.ji_to_transfer[j]][2] = input_rawbuffers[input_idx]._buf # src
-
-    # Update var_vals
-    for j in self.jc_idx_with_updatable_var_vals:
-      for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars):
-        self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[v])
-
-    # Update launch dims
-    for j in self.jc_idx_with_updatable_launch_dims:
-      gl, lc = cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals)
-      self.packets[j].workgroup_size_x = lc[0]
-      self.packets[j].workgroup_size_y = lc[1]
-      self.packets[j].workgroup_size_z = lc[2]
-      self.packets[j].grid_size_x = gl[0] * lc[0]
-      self.packets[j].grid_size_y = gl[1] * lc[1]
-      self.packets[j].grid_size_z = gl[2] * lc[2]
-
-    for dev in self.devices:
-      dev.flush_hdp()
-      dev.hw_queue.blit_packets(self.virt_aql_queues[dev].queue_base, self.virt_aql_queues[dev].packets_count)
-
-    for transfer_data in self.transfers:
-      check(hsa.hsa_amd_memory_async_copy_on_engine(*transfer_data))
-
-    et = None
-    if wait:
-      st = time.perf_counter()
-      hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
-      et = time.perf_counter() - st
-
-    for profdev,profdata in self.profile_info.items(): Profiler.tracked_signals[profdev] += profdata
-    return et
-
-  def alloc_signal(self, reset_on_start=False, wait_on=None):
-    sync_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x))))
-    if reset_on_start: self.signals_to_reset.append(sync_signal)
-    if wait_on is not None: self.signals_to_devices[sync_signal.handle] = wait_on
-    return sync_signal
-
-  def dependency_as_signal(self, dep, sync_with_aql_packets) -> Optional[hsa.hsa_signal_t]:
-    if isinstance(dep, hsa.hsa_signal_t): return dep
-    elif sync_with_aql_packets and isinstance(packet := self.packets.get(dep), hsa.hsa_kernel_dispatch_packet_t):
-      if packet.completion_signal.handle == EMPTY_SIGNAL.handle: packet.completion_signal = self.alloc_signal(reset_on_start=True)
-      return packet.completion_signal
-    return None
-
-  def access_resources(self, read, write, new_dependency, sync_with_aql_packets=False):
-    rdeps = self._access_resources(read, write, new_dependency)
-    wait_signals = [self.dependency_as_signal(dep, sync_with_aql_packets=sync_with_aql_packets) for dep in rdeps]
-    if sync_with_aql_packets: wait_signals += [self.kickoff_signals[cast(HSADevice, Device[rawbuf.device])] for rawbuf in read+write]
-    return dedup_signals(wait_signals)
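A detail worth noting in the deleted graph: an HSA barrier-AND packet carries at most 5 dependency signals (`dep_signal` has five slots), so `HSAGraph` submits its wait lists in chunks of 5, one barrier packet per chunk. An illustrative sketch of that chunking in plain Python (not tinygrad API):

```python
def barrier_chunks(wait_signals: list, max_deps: int = 5):
    # max(1, ...) mirrors the finish-signal loop above: an empty wait list
    # still yields one (empty) chunk, so a completion barrier is always sent.
    for i in range(0, max(1, len(wait_signals)), max_deps):
        yield wait_signals[i:i + max_deps]

print(list(barrier_chunks(list(range(12)))))
# [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11]]
```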