tinygrad-0.9.0-py3-none-any.whl → tinygrad-0.9.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. tinygrad/codegen/__init__.py +0 -0
  2. tinygrad/codegen/kernel.py +78 -90
  3. tinygrad/codegen/linearizer.py +237 -169
  4. tinygrad/codegen/uops.py +278 -242
  5. tinygrad/device.py +147 -10
  6. tinygrad/dtype.py +7 -7
  7. tinygrad/engine/graph.py +16 -16
  8. tinygrad/engine/jit.py +39 -36
  9. tinygrad/engine/realize.py +6 -5
  10. tinygrad/engine/schedule.py +15 -7
  11. tinygrad/engine/search.py +6 -3
  12. tinygrad/function.py +17 -23
  13. tinygrad/helpers.py +77 -8
  14. tinygrad/lazy.py +26 -26
  15. tinygrad/multi.py +13 -9
  16. tinygrad/nn/__init__.py +1 -1
  17. tinygrad/nn/datasets.py +2 -1
  18. tinygrad/nn/state.py +3 -4
  19. tinygrad/ops.py +49 -16
  20. tinygrad/renderer/__init__.py +8 -4
  21. tinygrad/renderer/assembly.py +93 -100
  22. tinygrad/renderer/cstyle.py +47 -42
  23. tinygrad/renderer/llvmir.py +30 -30
  24. tinygrad/runtime/__init__.py +0 -0
  25. tinygrad/runtime/autogen/amd_gpu.py +11504 -1
  26. tinygrad/runtime/autogen/comgr.py +36 -10
  27. tinygrad/runtime/autogen/hsa.py +146 -14
  28. tinygrad/runtime/autogen/io_uring.py +1486 -0
  29. tinygrad/runtime/autogen/nv_gpu.py +269 -0
  30. tinygrad/runtime/driver/__init__.py +0 -0
  31. tinygrad/runtime/driver/hip_comgr.py +20 -11
  32. tinygrad/runtime/graph/__init__.py +0 -0
  33. tinygrad/runtime/graph/clang.py +3 -2
  34. tinygrad/runtime/graph/cuda.py +2 -2
  35. tinygrad/runtime/graph/hcq.py +122 -78
  36. tinygrad/runtime/ops_amd.py +302 -316
  37. tinygrad/runtime/ops_cuda.py +3 -3
  38. tinygrad/runtime/ops_disk.py +70 -5
  39. tinygrad/runtime/ops_gpu.py +2 -2
  40. tinygrad/runtime/ops_metal.py +5 -6
  41. tinygrad/runtime/ops_npy.py +1 -1
  42. tinygrad/runtime/ops_nv.py +161 -166
  43. tinygrad/runtime/ops_python.py +20 -16
  44. tinygrad/shape/__init__.py +0 -0
  45. tinygrad/shape/shapetracker.py +5 -2
  46. tinygrad/shape/symbolic.py +1 -3
  47. tinygrad/shape/view.py +34 -19
  48. tinygrad/tensor.py +219 -135
  49. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +14 -6
  50. tinygrad-0.9.1.dist-info/RECORD +63 -0
  51. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
  52. tinygrad/runtime/driver/hsa.py +0 -143
  53. tinygrad/runtime/graph/hsa.py +0 -171
  54. tinygrad/runtime/ops_hsa.py +0 -278
  55. tinygrad-0.9.0.dist-info/RECORD +0 -60
  56. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +0 -0
  57. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: tinygrad
- Version: 0.9.0
+ Version: 0.9.1
  Summary: You like pytorch? You like micrograd? You love tinygrad! <3
  Author: George Hotz
  License: MIT
@@ -10,7 +10,6 @@ Requires-Python: >=3.8
  Description-Content-Type: text/markdown
  License-File: LICENSE
  Requires-Dist: numpy
- Requires-Dist: tqdm
  Requires-Dist: pyobjc-framework-Metal ; platform_system == "Darwin"
  Requires-Dist: pyobjc-framework-libdispatch ; platform_system == "Darwin"
  Provides-Extra: arm
@@ -39,6 +38,7 @@ Requires-Dist: onnx ==1.16.0 ; extra == 'testing'
  Requires-Dist: onnx2torch ; extra == 'testing'
  Requires-Dist: opencv-python ; extra == 'testing'
  Requires-Dist: tabulate ; extra == 'testing'
+ Requires-Dist: tqdm ; extra == 'testing'
  Requires-Dist: safetensors ; extra == 'testing'
  Requires-Dist: transformers ; extra == 'testing'
  Requires-Dist: sentencepiece ; extra == 'testing'
@@ -47,6 +47,7 @@ Requires-Dist: librosa ; extra == 'testing'
  Requires-Dist: networkx ; extra == 'testing'
  Requires-Dist: hypothesis ; extra == 'testing'
  Requires-Dist: nibabel ; extra == 'testing'
+ Requires-Dist: bottle ; extra == 'testing'
  Provides-Extra: testing_tf
  Requires-Dist: tensorflow ==2.15.1 ; extra == 'testing_tf'
  Requires-Dist: tensorflow-addons ; extra == 'testing_tf'
@@ -64,7 +65,7 @@ tinygrad: For something between [PyTorch](https://github.com/pytorch/pytorch) an

  <h3>

- [Homepage](https://github.com/tinygrad/tinygrad) | [Documentation](/docs) | [Examples](/examples) | [Showcase](/docs/showcase.md) | [Discord](https://discord.gg/ZjZadyC7PK)
+ [Homepage](https://github.com/tinygrad/tinygrad) | [Documentation](https://docs.tinygrad.org/) | [Discord](https://discord.gg/ZjZadyC7PK)

  </h3>

@@ -139,7 +140,8 @@ tinygrad already supports numerous accelerators, including:
  - [x] [LLVM](tinygrad/runtime/ops_llvm.py)
  - [x] [METAL](tinygrad/runtime/ops_metal.py)
  - [x] [CUDA](tinygrad/runtime/ops_cuda.py)
- - [x] [HSA](tinygrad/runtime/ops_hsa.py)
+ - [x] [AMD](tinygrad/runtime/ops_amd.py)
+ - [x] [NV](tinygrad/runtime/ops_nv.py)

  And it is easy to add more! Your accelerator of choice only needs to support a total of ~25 low level ops.

@@ -163,7 +165,7 @@ python3 -m pip install git+https://github.com/tinygrad/tinygrad.git

  ## Documentation

- Documentation along with a quick start guide can be found in the [docs/](/docs) directory.
+ Documentation along with a quick start guide can be found on the [docs website](https://docs.tinygrad.org/) built from the [docs/](/docs) directory.

  ### Quick example comparing to PyTorch

@@ -209,7 +211,7 @@ Now, what we want:
  - Bug fixes (with a regression test) are great! This library isn't 1.0 yet, so if you stumble upon a bug, fix it, write a test, and submit a PR, this is valuable work.
  - Solving bounties! tinygrad [offers cash bounties](https://docs.google.com/spreadsheets/d/1WKHbT-7KOgjEawq5h5Ic1qUWzpfAzuD_J06N1JwOCGs/edit?usp=sharing) for certain improvements to the library. All new code should be high quality and well tested.
  - Features. However, if you are adding a feature, consider the line tradeoff. If it's 3 lines, there's less of a bar of usefulness it has to meet over something that's 30 or 300 lines. All features must have regression tests. In general with no other constraints, your feature's API should match torch or numpy.
- - Refactors that are clear wins. In general, if your refactor isn't a clear win it will be closed. But some refactors are amazing! Think about readability in a deep core sense. A whitespace change or moving a few functions around is useless, but if you realize that two 100 line functions can actually use the same 110 line function with arguments while also improving readability, this is a big win.
+ - Refactors that are clear wins. In general, if your refactor isn't a clear win it will be closed. But some refactors are amazing! Think about readability in a deep core sense. A whitespace change or moving a few functions around is useless, but if you realize that two 100 line functions can actually use the same 110 line function with arguments while also improving readability, this is a big win. Refactors should pass [process replay](#process-replay-tests).
  - Tests/fuzzers. If you can add tests that are non brittle, they are welcome. We have some fuzzers in here too, and there's a plethora of bugs that can be found with them and by improving them. Finding bugs, even writing broken tests (that should pass) with `@unittest.expectedFailure` is great. This is how we make progress.
  - Dead code removal from core `tinygrad/` folder. We don't care about the code in extra, but removing dead code from the core library is great. Less for new people to read and be confused by.

@@ -225,3 +227,9 @@ python3 -m pip install -e '.[testing]' # install extra deps for testing
  python3 test/test_ops.py # just the ops tests
  python3 -m pytest test/ # whole test suite
  ```
+
+ #### Process replay tests
+
+ [Process replay](https://github.com/tinygrad/tinygrad/blob/master/test/external/process_replay/process_replay.py) detects changes in the generated kernels of CI tests by comparing them against tinygrad master. If your PR is a refactor or speedup without any expected behavior change, it should include a green process replay pass to get merged.
+
+ You can enable process replay by adding [run_process_replay] to your PR title. [example](https://github.com/tinygrad/tinygrad/pull/4995). Note that you should keep your branch up-to-date with master.
@@ -0,0 +1,63 @@
+ tinygrad/__init__.py,sha256=jC-35zswLSXLuRRThG_o6yar6qQjLCqmeaFCj_XKN08,449
+ tinygrad/device.py,sha256=oy8vQAn0bLSVnLtu6MfusWF6Axv83xEwrJB3C1sWcPo,18287
+ tinygrad/dtype.py,sha256=LGHdreCNcM-FJTwT3uNrsGIMCSnBHi4ZeGf3K0QHXGU,6199
+ tinygrad/function.py,sha256=1PS3nZBicJmMAZ-IvE0SJAHWrGyRMkWM9-J_Hr5az0E,9473
+ tinygrad/helpers.py,sha256=SQfdz3AxiAT_k7GExSIXKBYqvCugxaF7ftQ9gr8qvFg,16814
+ tinygrad/lazy.py,sha256=mrmxsa6dAQhVAm2UMs7yM18Y-IeJyG4rqyW4-LL5apQ,13336
+ tinygrad/multi.py,sha256=EK9OTqmt2EaEgJIIebZUe944NaoJLTFbnUGxVYIk--k,11520
+ tinygrad/ops.py,sha256=ZOKpgKY_y3BGqu5NpL172me10Lw4SYVkkCUeD3HR7MI,8607
+ tinygrad/tensor.py,sha256=-TRbZ1ydsxgsHUlMDfx6AuEY4qtxUbYhYQ8hK7w8mt0,133764
+ tinygrad/codegen/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tinygrad/codegen/kernel.py,sha256=Bw53Hb0vfdyyz2JQhE7YV7DqfZh3gVsxtLFABkx8Xps,36961
+ tinygrad/codegen/linearizer.py,sha256=oxqSuzBwpbOtZL1vTxVR4qQqjdvvAUwayhVMltkkX1s,32964
+ tinygrad/codegen/uops.py,sha256=GzEwfeyATi9uydzHnyFeN_V47hrk7pkz3gHePjyp6jY,25454
+ tinygrad/engine/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tinygrad/engine/graph.py,sha256=2w16XKpAXSkEML8Hl3m1bbPnLDV-MBsaBJLBPx2QzA0,5264
+ tinygrad/engine/jit.py,sha256=xoiApaot299mcjKPbxwUlqu2SJ3Tbyfs_rm5zKuHNMA,11355
+ tinygrad/engine/realize.py,sha256=LsfZIHQ2Zz_79a4LKBU1pMUwbleM1x_oBTyn4D1qRc8,11214
+ tinygrad/engine/schedule.py,sha256=ltOshvCgUU5wHqTDVLhxPVo7DGq7VRA6ECRX6C_09oU,18984
+ tinygrad/engine/search.py,sha256=rLS1Aj_r42L0_uUJRyytqFTpJvXxPdx0LRD7BzvF2dQ,11359
+ tinygrad/nn/__init__.py,sha256=-CH1lNAoxJQjFeGFpp78RYLaD2sD4Q4biJA5uvWwm88,12942
+ tinygrad/nn/datasets.py,sha256=3tau-L-HiopP1leFDXZV-VamCR0DxOhAu6wVOkAVSZY,493
+ tinygrad/nn/optim.py,sha256=zf85kwumpS17fk1NBjUfe7tOUE7-XH37SL-LjgAfva8,6799
+ tinygrad/nn/state.py,sha256=cYxsy6IiMCl1wITZcM04GirsftFOutp-2X0PIow35aY,10111
+ tinygrad/renderer/__init__.py,sha256=OVEBlNy-eaREIn9jrY0tBhVsmMYi1eZK3wW_P_f26-M,2935
+ tinygrad/renderer/assembly.py,sha256=syUM3dMOrmIEkmVWzOS47E2sZXDUWbVXmkofl0RnnWw,18011
+ tinygrad/renderer/cstyle.py,sha256=O2MNxsxZzMCPY7JAC_5w5aUqosJSi4LGdH8PZoA7s2A,24387
+ tinygrad/renderer/llvmir.py,sha256=43iw-cfwukwffibuZQOm3jI3fPPcRT8rs4-JVglVIHU,10192
+ tinygrad/runtime/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tinygrad/runtime/ops_amd.py,sha256=h4ZwP64C82aNuHTUtjRykiEzsdhUwGKtvfNdsXOUXyI,31393
+ tinygrad/runtime/ops_clang.py,sha256=XWqwobReRdu-Tj-chbWEJFMx6AQfgdGCcpdWcLWUTOQ,1468
+ tinygrad/runtime/ops_cuda.py,sha256=LujFWnZvLEaPlebomy9c3C4WAbbaEswdeuJgCNDIiGg,10860
+ tinygrad/runtime/ops_disk.py,sha256=5AGIXyGb_x3zuPQ54M6p0xkROjtQo0MMeiqSUPnQ6MQ,6655
+ tinygrad/runtime/ops_gpu.py,sha256=z9MPEQawWbJGfLyN3KOgxXaENEbi3ULm8xMWcrln1c0,7532
+ tinygrad/runtime/ops_llvm.py,sha256=dODiyVSlPofHyDIZrD-V74h8W1d94VPnr_-A4gNbSO4,2229
+ tinygrad/runtime/ops_metal.py,sha256=cb4xIXp_4wyMjk7E2XuSwcCzh1JEwc21p6hJxB8LPgc,5802
+ tinygrad/runtime/ops_npy.py,sha256=HSDb47vcbPvDiG3EmSGPC0W-EPO9O5R0TqlQKnYxANU,404
+ tinygrad/runtime/ops_nv.py,sha256=ClqazkWrMOheiyRDl1ZY7azMPqFHstnEhlAEKNSkklg,37657
+ tinygrad/runtime/ops_python.py,sha256=2B1ty1fr4SpPq-bfEyA-gDxxvjI7ulS1EAqHwHFn-QE,11051
+ tinygrad/runtime/autogen/amd_gpu.py,sha256=kqu1ygAUxymdwFznsIrXnsJYDa0UDpVndTlTuj9XK_4,581516
+ tinygrad/runtime/autogen/comgr.py,sha256=mhQtuF_vGDrJZD3iyQ_38FrS_2jp3WtImNZPC4TBuAg,39793
+ tinygrad/runtime/autogen/cuda.py,sha256=GgRl4AfU54JG0G1XJj2dq2FbrUZ8XG_AnFrPAZJpSSg,250380
+ tinygrad/runtime/autogen/hip.py,sha256=1yUHDCwL3KkD15if2Q1Ud3GbJiR7DxsNorKZTCINw54,245532
+ tinygrad/runtime/autogen/hsa.py,sha256=7Hsrn17HmChyeFOSX_3Fnzl9c0COtq2Z2ExqGu5FNiU,277716
+ tinygrad/runtime/autogen/io_uring.py,sha256=cWYzokworiNV1x4I6eBmZpqV2nCt5B_rVEDx7pFP2W0,57664
+ tinygrad/runtime/autogen/kfd.py,sha256=dDmLFL6HL_QXRW2rZOCArY55PRuXuLN9563XCipV2jM,29935
+ tinygrad/runtime/autogen/nv_gpu.py,sha256=Hd8p-ay8dbdqFwvGyEtXLX3Ymjr3DP1XSRMAe0zKrq4,1686859
+ tinygrad/runtime/autogen/opencl.py,sha256=aW-luGFF5PXFEZ6SgrGANhA9qpkI-fZyEsmDfpez2Ss,82654
+ tinygrad/runtime/driver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tinygrad/runtime/driver/hip_comgr.py,sha256=rCS6DaUjkYflI1OvP5XGCOvcy67XCAgeSvQMUb1K9iY,3840
+ tinygrad/runtime/graph/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tinygrad/runtime/graph/clang.py,sha256=NjobP8AVzNC-CS25FjO5gl2-B8KGJb8K8kLzNXEycPg,2005
+ tinygrad/runtime/graph/cuda.py,sha256=ZXOhUhTzz3L2ut-E0tCTrqTEcDFExx_FaZzBm8HcQIk,5268
+ tinygrad/runtime/graph/hcq.py,sha256=iJcAw-95PXwjkNgQbwwhFmq1m9Zprq8wwPNFkil_SEQ,11877
+ tinygrad/runtime/graph/metal.py,sha256=bwB6uAsqjEbwv5ML5ziWduBtmTpseJboo6J9ssVa4v4,4579
+ tinygrad/shape/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tinygrad/shape/shapetracker.py,sha256=oJqr3q9HoDwNN8S2VV9H_IpmG_ovGY10CTSaWWZjUz4,6387
+ tinygrad/shape/symbolic.py,sha256=ORyH9bkb94u3ak0XvM7TvZj2n-_RlR9FcYlQqmIUc2M,16617
+ tinygrad/shape/view.py,sha256=H9tdtNWD1xKci_Qhf3EJsWKPD8x4M-e7MAxgoUDqG9k,17999
+ tinygrad-0.9.1.dist-info/LICENSE,sha256=ABRhUPEILzINYIukgazD-_rPipkUNUwslrb0RxnV6Xc,1058
+ tinygrad-0.9.1.dist-info/METADATA,sha256=YApgO_KsHfOTp9m10ytDj_VIVghIq5GP9PTIiIhre2c,10990
+ tinygrad-0.9.1.dist-info/WHEEL,sha256=mguMlWGMX-VHnMpKOjjQidIo1ssRlCFu4a4mBpz1s2M,91
+ tinygrad-0.9.1.dist-info/top_level.txt,sha256=vDABMCWBFQnx2kn9Azueu88FP-1klQdePoHikQhHymc,9
+ tinygrad-0.9.1.dist-info/RECORD,,
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: bdist_wheel (0.43.0)
+ Generator: setuptools (70.1.1)
  Root-Is-Purelib: true
  Tag: py3-none-any

@@ -1,143 +0,0 @@
- import ctypes, collections
- import tinygrad.runtime.autogen.hsa as hsa
- from tinygrad.helpers import init_c_var
-
- def check(status):
-   if status != 0:
-     hsa.hsa_status_string(status, ctypes.byref(status_str := ctypes.POINTER(ctypes.c_char)()))
-     raise RuntimeError(f"HSA Error {status}: {ctypes.string_at(status_str).decode()}")
-
- # Precalulated AQL info
- AQL_PACKET_SIZE = ctypes.sizeof(hsa.hsa_kernel_dispatch_packet_t)
- EMPTY_SIGNAL = hsa.hsa_signal_t()
-
- DISPATCH_KERNEL_SETUP = 3 << hsa.HSA_KERNEL_DISPATCH_PACKET_SETUP_DIMENSIONS
- DISPATCH_KERNEL_HEADER = 1 << hsa.HSA_PACKET_HEADER_BARRIER
- DISPATCH_KERNEL_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE
- DISPATCH_KERNEL_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE
- DISPATCH_KERNEL_HEADER |= hsa.HSA_PACKET_TYPE_KERNEL_DISPATCH << hsa.HSA_PACKET_HEADER_TYPE
-
- BARRIER_HEADER = 1 << hsa.HSA_PACKET_HEADER_BARRIER
- BARRIER_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE
- BARRIER_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE
- BARRIER_HEADER |= hsa.HSA_PACKET_TYPE_BARRIER_AND << hsa.HSA_PACKET_HEADER_TYPE
-
- class AQLQueue:
-   def __init__(self, device, sz=-1):
-     self.device = device
-
-     check(hsa.hsa_agent_get_info(self.device.agent, hsa.HSA_AGENT_INFO_QUEUE_MAX_SIZE, ctypes.byref(max_queue_size := ctypes.c_uint32())))
-     queue_size = min(max_queue_size.value, sz) if sz != -1 else max_queue_size.value
-
-     null_func = ctypes.CFUNCTYPE(None, hsa.hsa_status_t, ctypes.POINTER(hsa.struct_hsa_queue_s), ctypes.c_void_p)()
-     self.hw_queue = init_c_var(ctypes.POINTER(hsa.hsa_queue_t)(), lambda x: check(
-       hsa.hsa_queue_create(self.device.agent, queue_size, hsa.HSA_QUEUE_TYPE_SINGLE, null_func, None, (1<<32)-1, (1<<32)-1, ctypes.byref(x))))
-
-     self.next_doorbell_index = 0
-     self.queue_base = self.hw_queue.contents.base_address
-     self.queue_size = self.hw_queue.contents.size * AQL_PACKET_SIZE # in bytes
-     self.write_addr = self.queue_base
-     self.write_addr_end = self.queue_base + self.queue_size - 1 # precalc saves some time
-     self.available_packet_slots = self.hw_queue.contents.size
-
-     check(hsa.hsa_amd_queue_set_priority(self.hw_queue, hsa.HSA_AMD_QUEUE_PRIORITY_HIGH))
-     check(hsa.hsa_amd_profiling_set_profiler_enabled(self.hw_queue, 1))
-
-   def __del__(self):
-     if hasattr(self, 'hw_queue'): check(hsa.hsa_queue_destroy(self.hw_queue))
-
-   def submit_kernel(self, prg, global_size, local_size, kernargs, completion_signal=None):
-     if self.available_packet_slots == 0: self._wait_queue()
-
-     packet = hsa.hsa_kernel_dispatch_packet_t.from_address(self.write_addr)
-     packet.workgroup_size_x = local_size[0]
-     packet.workgroup_size_y = local_size[1]
-     packet.workgroup_size_z = local_size[2]
-     packet.reserved0 = 0
-     packet.grid_size_x = global_size[0] * local_size[0]
-     packet.grid_size_y = global_size[1] * local_size[1]
-     packet.grid_size_z = global_size[2] * local_size[2]
-     packet.private_segment_size = prg.private_segment_size
-     packet.group_segment_size = prg.group_segment_size
-     packet.kernel_object = prg.handle
-     packet.kernarg_address = kernargs
-     packet.reserved2 = 0
-     packet.completion_signal = completion_signal if completion_signal else EMPTY_SIGNAL
-     packet.setup = DISPATCH_KERNEL_SETUP
-     packet.header = DISPATCH_KERNEL_HEADER
-     self._submit_packet()
-
-   def submit_barrier(self, wait_signals=None, completion_signal=None):
-     assert wait_signals is None or len(wait_signals) <= 5
-     if self.available_packet_slots == 0: self._wait_queue()
-
-     packet = hsa.hsa_barrier_and_packet_t.from_address(self.write_addr)
-     packet.reserved0 = 0
-     packet.reserved1 = 0
-     for i in range(5):
-       packet.dep_signal[i] = wait_signals[i] if wait_signals and len(wait_signals) > i else EMPTY_SIGNAL
-     packet.reserved2 = 0
-     packet.completion_signal = completion_signal if completion_signal else EMPTY_SIGNAL
-     packet.header = BARRIER_HEADER
-     self._submit_packet()
-
-   def blit_packets(self, packet_addr, packet_cnt):
-     if self.available_packet_slots < packet_cnt: self._wait_queue(packet_cnt)
-
-     tail_blit_packets = min((self.queue_base + self.queue_size - self.write_addr) // AQL_PACKET_SIZE, packet_cnt)
-     rem_packet_cnt = packet_cnt - tail_blit_packets
-     ctypes.memmove(self.write_addr, packet_addr, AQL_PACKET_SIZE * tail_blit_packets)
-     if rem_packet_cnt > 0: ctypes.memmove(self.queue_base, packet_addr + AQL_PACKET_SIZE * tail_blit_packets, AQL_PACKET_SIZE * rem_packet_cnt)
-
-     self._submit_packet(packet_cnt)
-
-   def wait(self):
-     self.submit_barrier([], finish_signal := self.device.alloc_signal(reusable=True))
-     hsa.hsa_signal_wait_scacquire(finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
-     self.available_packet_slots = self.queue_size // AQL_PACKET_SIZE
-
-   def _wait_queue(self, need_packets=1):
-     while self.available_packet_slots < need_packets:
-       rindex = hsa.hsa_queue_load_read_index_relaxed(self.hw_queue)
-       self.available_packet_slots = self.queue_size // AQL_PACKET_SIZE - (self.next_doorbell_index - rindex)
-
-   def _submit_packet(self, cnt=1):
-     self.available_packet_slots -= cnt
-     self.next_doorbell_index += cnt
-     hsa.hsa_queue_store_write_index_relaxed(self.hw_queue, self.next_doorbell_index)
-     hsa.hsa_signal_store_screlease(self.hw_queue.contents.doorbell_signal, self.next_doorbell_index-1)
-
-     self.write_addr += AQL_PACKET_SIZE * cnt
-     if self.write_addr > self.write_addr_end:
-       self.write_addr = self.queue_base + (self.write_addr - self.queue_base) % self.queue_size
-
- def scan_agents():
-   agents = collections.defaultdict(list)
-
-   @ctypes.CFUNCTYPE(hsa.hsa_status_t, hsa.hsa_agent_t, ctypes.c_void_p)
-   def __scan_agents(agent, data):
-     status = hsa.hsa_agent_get_info(agent, hsa.HSA_AGENT_INFO_DEVICE, ctypes.byref(device_type := hsa.hsa_device_type_t()))
-     if status == 0: agents[device_type.value].append(agent)
-     return hsa.HSA_STATUS_SUCCESS
-
-   hsa.hsa_iterate_agents(__scan_agents, None)
-   return agents
-
- def find_memory_pool(agent, segtyp=-1, location=-1):
-   @ctypes.CFUNCTYPE(hsa.hsa_status_t, hsa.hsa_amd_memory_pool_t, ctypes.c_void_p)
-   def __filter_amd_memory_pools(mem_pool, data):
-     check(hsa.hsa_amd_memory_pool_get_info(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_SEGMENT, ctypes.byref(segment := hsa.hsa_amd_segment_t())))
-     if segtyp >= 0 and segment.value != segtyp: return hsa.HSA_STATUS_SUCCESS
-
-     check(hsa.hsa_amd_memory_pool_get_info(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_LOCATION, ctypes.byref(loc:=hsa.hsa_amd_memory_pool_location_t())))
-     if location >= 0 and loc.value != location: return hsa.HSA_STATUS_SUCCESS
-
-     check(hsa.hsa_amd_memory_pool_get_info(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_SIZE, ctypes.byref(sz := ctypes.c_size_t())))
-     if sz.value == 0: return hsa.HSA_STATUS_SUCCESS
-
-     ret = ctypes.cast(data, ctypes.POINTER(hsa.hsa_amd_memory_pool_t))
-     ret[0] = mem_pool
-     return hsa.HSA_STATUS_INFO_BREAK
-
-   hsa.hsa_amd_agent_iterate_memory_pools(agent, __filter_amd_memory_pools, ctypes.byref(region := hsa.hsa_amd_memory_pool_t()))
-   return region
@@ -1,171 +0,0 @@
- import ctypes, collections, time, itertools
- from typing import List, Any, Dict, cast, Optional, Tuple
- from tinygrad.helpers import GraphException, init_c_var, round_up
- from tinygrad.device import Buffer, BufferOptions
- from tinygrad.device import Compiled, Device
- from tinygrad.shape.symbolic import Variable
- from tinygrad.runtime.ops_hsa import HSADevice, PROFILE, Profiler
- from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
- from tinygrad.engine.jit import MultiGraphRunner
- import tinygrad.runtime.autogen.hsa as hsa
- from tinygrad.runtime.driver.hsa import check, AQLQueue, AQL_PACKET_SIZE, EMPTY_SIGNAL
-
- def dedup_signals(signals): return [hsa.hsa_signal_t(hndl) for hndl in set([x.handle for x in signals if isinstance(x, hsa.hsa_signal_t)])]
-
- class VirtAQLQueue(AQLQueue):
-   def __init__(self, device, sz):
-     self.device = device
-     self.virt_queue = (hsa.hsa_kernel_dispatch_packet_t * sz)()
-     self.queue_base = self.write_addr = ctypes.addressof(self.virt_queue)
-     self.packets_count = 0
-     self.available_packet_slots = sz
-   def _wait_queue(self, need_packets=1): assert False, f"VirtQueue is too small to handle {self.packets_count+need_packets} packets!"
-   def _submit_packet(self):
-     self.write_addr += AQL_PACKET_SIZE
-     self.packets_count += 1
-     self.available_packet_slots -= 1
-
- class HSAGraph(MultiGraphRunner):
-   def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
-     super().__init__(jit_cache, input_rawbuffers, var_vals)
-
-     # Check all jit items are compatible.
-     compiled_devices = set()
-     for ji in self.jit_cache:
-       if isinstance(ji.prg, CompiledRunner): compiled_devices.add(ji.prg.device)
-       elif isinstance(ji.prg, BufferXfer):
-         for x in ji.bufs[0:2]: compiled_devices.add(Device[cast(Buffer, x).device])
-       else: raise GraphException
-     if any(not isinstance(d, HSADevice) for d in compiled_devices): raise GraphException
-
-     self.devices: List[HSADevice] = list(compiled_devices) #type:ignore
-
-     # Allocate kernel args.
-     kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
-     for ji in self.jit_cache:
-       if isinstance(ji.prg, CompiledRunner): kernargs_size[ji.prg.device] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
-     kernargs_ptrs: Dict[Compiled, int] = {dev:dev.allocator._alloc(sz, BufferOptions()) for dev,sz in kernargs_size.items()}
-
-     # Fill initial arguments.
-     self.ji_kargs_structs: Dict[int, ctypes.Structure] = {}
-     for j,ji in enumerate(self.jit_cache):
-       if not isinstance(ji.prg, CompiledRunner): continue
-       self.ji_kargs_structs[j] = ji.prg.clprg.args_struct_t.from_address(kernargs_ptrs[ji.prg.device])
-       kernargs_ptrs[ji.prg.device] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
-       for i in range(len(ji.bufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.bufs[i])._buf)
-       for i in range(len(ji.prg.p.vars)): self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[ji.prg.p.vars[i]])
-
-     # Build queues.
-     self.virt_aql_queues: Dict[Compiled, VirtAQLQueue] = {dev:VirtAQLQueue(dev, 2*len(self.jit_cache)+16) for dev in self.devices}
-     self.packets = {}
-     self.transfers = []
-     self.ji_to_transfer: Dict[int, int] = {} # faster to store transfers as list and update using this mapping table.
-     self.signals_to_reset: List[hsa.hsa_signal_t] = []
-     self.signals_to_devices: Dict[ctypes.c_uint64, List[HSADevice]] = {}
-     self.profile_info: Dict[Compiled, List[Tuple[Any, ...]]] = collections.defaultdict(list)
-
-     # Special packet to wait for the world.
-     self.kickoff_signals: Dict[HSADevice, hsa.hsa_signal_t] = {dev:self.alloc_signal(reset_on_start=True) for dev in self.devices}
-     for dev in self.devices: self.virt_aql_queues[dev].submit_barrier([], self.kickoff_signals[dev])
-
-     for j,ji in enumerate(self.jit_cache):
-       if isinstance(ji.prg, CompiledRunner):
-         wait_signals = self.access_resources(ji.bufs[(outs:=ji.prg.p.outcount):], ji.bufs[:outs], new_dependency=j, sync_with_aql_packets=False)
-         for i in range(0, len(wait_signals), 5):
-           self.virt_aql_queues[ji.prg.device].submit_barrier(wait_signals[i:i+5])
-         self.packets[j] = hsa.hsa_kernel_dispatch_packet_t.from_address(self.virt_aql_queues[ji.prg.device].write_addr)
-
-         sync_signal = self.alloc_signal(reset_on_start=True) if PROFILE else None
-         self.virt_aql_queues[ji.prg.device].submit_kernel(ji.prg.clprg, *ji.prg.p.launch_dims(var_vals), #type:ignore
-                                                           ctypes.addressof(self.ji_kargs_structs[j]), completion_signal=sync_signal)
-         if PROFILE: self.profile_info[ji.prg.device].append((sync_signal, ji.prg.clprg.name, False))
-       elif isinstance(ji.prg, BufferXfer):
-         dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
-         dest_dev, src_dev = cast(HSADevice, Device[dest.device]), cast(HSADevice, Device[src.device])
-         sync_signal = self.alloc_signal(reset_on_start=True, wait_on=[dest_dev, src_dev])
-
-         wait_signals = self.access_resources(read=[src], write=[dest], new_dependency=sync_signal, sync_with_aql_packets=True)
-         self.transfers.append([dest._buf, dest_dev.agent, src._buf, src_dev.agent, dest.nbytes, len(wait_signals),
-                                (hsa.hsa_signal_t*len(wait_signals))(*wait_signals), sync_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True])
-         self.ji_to_transfer[j] = len(self.transfers) - 1
-         if PROFILE: self.profile_info[src_dev].append((sync_signal, f"transfer: HSA:{src_dev.device_id} -> HSA:{dest_dev.device_id}", True))
-
-     # Wait for all active signals to finish the graph
-     wait_signals_to_finish: Dict[HSADevice, List[hsa.hsa_signal_t]] = collections.defaultdict(list)
-     for v in dedup_signals(list(self.w_dependency_map.values()) + list(itertools.chain.from_iterable(self.r_dependency_map.values()))):
-       for dev in self.signals_to_devices[v.handle]:
-         wait_signals_to_finish[dev].append(v)
-
-     self.finish_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x))))
-     for dev in self.devices:
-       wait_signals = wait_signals_to_finish[dev]
-       for i in range(0, max(1, len(wait_signals)), 5):
-         self.virt_aql_queues[dev].submit_barrier(wait_signals[i:i+5], completion_signal=self.finish_signal if i+5>=len(wait_signals) else None)
-
-     # Zero signals to allow graph to start and execute.
-     for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 0)
-     hsa.hsa_signal_silent_store_relaxed(self.finish_signal, 0)
-
-   def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
-     # Wait and restore signals
-     hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
-     for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 1)
-     hsa.hsa_signal_silent_store_relaxed(self.finish_signal, len(self.devices))
-
-     # Update rawbuffers
-     for (j,i),input_idx in self.input_replace.items():
-       if j in self.ji_kargs_structs:
-         self.ji_kargs_structs[j].__setattr__(f'f{i}', input_rawbuffers[input_idx]._buf)
-       else:
-         if i == 0: self.transfers[self.ji_to_transfer[j]][0] = input_rawbuffers[input_idx]._buf # dest
-         elif i == 1: self.transfers[self.ji_to_transfer[j]][2] = input_rawbuffers[input_idx]._buf # src
-
-     # Update var_vals
-     for j in self.jc_idx_with_updatable_var_vals:
-       for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars):
-         self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[v])
-
-     # Update launch dims
-     for j in self.jc_idx_with_updatable_launch_dims:
-       gl, lc = cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals)
-       self.packets[j].workgroup_size_x = lc[0]
-       self.packets[j].workgroup_size_y = lc[1]
-       self.packets[j].workgroup_size_z = lc[2]
-       self.packets[j].grid_size_x = gl[0] * lc[0]
-       self.packets[j].grid_size_y = gl[1] * lc[1]
-       self.packets[j].grid_size_z = gl[2] * lc[2]
-
-     for dev in self.devices:
-       dev.flush_hdp()
-       dev.hw_queue.blit_packets(self.virt_aql_queues[dev].queue_base, self.virt_aql_queues[dev].packets_count)
-
-     for transfer_data in self.transfers:
-       check(hsa.hsa_amd_memory_async_copy_on_engine(*transfer_data))
-
-     et = None
-     if wait:
-       st = time.perf_counter()
-       hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
-       et = time.perf_counter() - st
-
-     for profdev,profdata in self.profile_info.items(): Profiler.tracked_signals[profdev] += profdata
-     return et
-
-   def alloc_signal(self, reset_on_start=False, wait_on=None):
-     sync_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x))))
-     if reset_on_start: self.signals_to_reset.append(sync_signal)
-     if wait_on is not None: self.signals_to_devices[sync_signal.handle] = wait_on
-     return sync_signal
-
-   def dependency_as_signal(self, dep, sync_with_aql_packets) -> Optional[hsa.hsa_signal_t]:
-     if isinstance(dep, hsa.hsa_signal_t): return dep
-     elif sync_with_aql_packets and isinstance(packet := self.packets.get(dep), hsa.hsa_kernel_dispatch_packet_t):
-       if packet.completion_signal.handle == EMPTY_SIGNAL.handle: packet.completion_signal = self.alloc_signal(reset_on_start=True)
-       return packet.completion_signal
-     return None
-
-   def access_resources(self, read, write, new_dependency, sync_with_aql_packets=False):
-     rdeps = self._access_resources(read, write, new_dependency)
-     wait_signals = [self.dependency_as_signal(dep, sync_with_aql_packets=sync_with_aql_packets) for dep in rdeps]
-     if sync_with_aql_packets: wait_signals += [self.kickoff_signals[cast(HSADevice, Device[rawbuf.device])] for rawbuf in read+write]
-     return dedup_signals(wait_signals)