tinygrad 0.8.0__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only.
Files changed (74)
  1. tinygrad/__init__.py +6 -6
  2. tinygrad/codegen/__init__.py +0 -0
  3. tinygrad/codegen/kernel.py +253 -225
  4. tinygrad/codegen/linearizer.py +398 -436
  5. tinygrad/codegen/uops.py +451 -0
  6. tinygrad/device.py +268 -274
  7. tinygrad/dtype.py +56 -40
  8. tinygrad/engine/__init__.py +0 -0
  9. tinygrad/engine/graph.py +100 -0
  10. tinygrad/engine/jit.py +198 -0
  11. tinygrad/engine/realize.py +192 -0
  12. tinygrad/engine/schedule.py +370 -0
  13. tinygrad/engine/search.py +199 -0
  14. tinygrad/{mlops.py → function.py} +40 -32
  15. tinygrad/helpers.py +144 -46
  16. tinygrad/lazy.py +143 -242
  17. tinygrad/multi.py +173 -0
  18. tinygrad/nn/__init__.py +180 -9
  19. tinygrad/nn/datasets.py +8 -0
  20. tinygrad/nn/optim.py +106 -28
  21. tinygrad/nn/state.py +87 -19
  22. tinygrad/ops.py +104 -45
  23. tinygrad/renderer/__init__.py +65 -0
  24. tinygrad/renderer/assembly.py +269 -0
  25. tinygrad/renderer/cstyle.py +308 -210
  26. tinygrad/renderer/llvmir.py +119 -124
  27. tinygrad/runtime/__init__.py +0 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +13403 -0
  29. tinygrad/runtime/autogen/comgr.py +891 -0
  30. tinygrad/runtime/autogen/cuda.py +5923 -0
  31. tinygrad/runtime/autogen/hip.py +5909 -0
  32. tinygrad/runtime/autogen/hsa.py +5893 -0
  33. tinygrad/runtime/autogen/io_uring.py +1486 -0
  34. tinygrad/runtime/autogen/kfd.py +812 -0
  35. tinygrad/runtime/autogen/nv_gpu.py +33597 -0
  36. tinygrad/runtime/autogen/opencl.py +1795 -0
  37. tinygrad/runtime/driver/__init__.py +0 -0
  38. tinygrad/runtime/driver/hip_comgr.py +56 -0
  39. tinygrad/runtime/graph/__init__.py +0 -0
  40. tinygrad/runtime/graph/clang.py +39 -0
  41. tinygrad/runtime/graph/cuda.py +59 -54
  42. tinygrad/runtime/graph/hcq.py +187 -0
  43. tinygrad/runtime/graph/metal.py +37 -41
  44. tinygrad/runtime/ops_amd.py +550 -0
  45. tinygrad/runtime/ops_clang.py +16 -14
  46. tinygrad/runtime/ops_cuda.py +129 -37
  47. tinygrad/runtime/ops_disk.py +111 -43
  48. tinygrad/runtime/ops_gpu.py +52 -50
  49. tinygrad/runtime/ops_llvm.py +36 -56
  50. tinygrad/runtime/ops_metal.py +41 -24
  51. tinygrad/runtime/ops_npy.py +9 -0
  52. tinygrad/runtime/ops_nv.py +625 -0
  53. tinygrad/runtime/ops_python.py +208 -0
  54. tinygrad/shape/__init__.py +0 -0
  55. tinygrad/shape/shapetracker.py +46 -107
  56. tinygrad/shape/symbolic.py +99 -98
  57. tinygrad/shape/view.py +162 -45
  58. tinygrad/tensor.py +2492 -483
  59. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +1 -1
  60. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +31 -13
  61. tinygrad-0.9.1.dist-info/RECORD +63 -0
  62. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
  63. tinygrad/features/image.py +0 -93
  64. tinygrad/features/multi.py +0 -103
  65. tinygrad/features/search.py +0 -160
  66. tinygrad/graph.py +0 -106
  67. tinygrad/jit.py +0 -152
  68. tinygrad/realize.py +0 -50
  69. tinygrad/runtime/graph/hip.py +0 -24
  70. tinygrad/runtime/ops_cpu.py +0 -45
  71. tinygrad/runtime/ops_hip.py +0 -97
  72. tinygrad/runtime/ops_torch.py +0 -49
  73. tinygrad-0.8.0.dist-info/RECORD +0 -41
  74. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
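Several of the renames above (tinygrad/mlops.py → tinygrad/function.py, plus the moves of jit.py, graph.py, realize.py, features/search.py, and features/multi.py into tinygrad/engine/ and tinygrad/multi.py) change import paths for downstream code that reached into these modules directly. A minimal sketch of the corresponding import migration, assuming the public symbol names (TinyJit, beam_search, MultiLazyBuffer) were kept across the move; verify against the 0.9.1 sources before relying on it:

```python
# Hypothetical import migration for downstream code, inferred from the file
# renames listed above (0.8.0 -> 0.9.1). The symbol locations are assumptions.

# tinygrad 0.8.0
#   from tinygrad.jit import TinyJit
#   from tinygrad.features.search import beam_search
#   from tinygrad.features.multi import MultiLazyBuffer

# tinygrad 0.9.1
from tinygrad.engine.jit import TinyJit         # tinygrad/jit.py -> tinygrad/engine/jit.py
from tinygrad.engine.search import beam_search  # features/search.py -> engine/search.py
from tinygrad.multi import MultiLazyBuffer      # features/multi.py -> multi.py
```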
{tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE
@@ -1,4 +1,4 @@
- Copyright (c) 2023 George Hotz
+ Copyright (c) 2024, the tiny corp
 
  Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
 
{tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: tinygrad
- Version: 0.8.0
+ Version: 0.9.1
  Summary: You like pytorch? You like micrograd? You love tinygrad! <3
  Author: George Hotz
  License: MIT
@@ -10,12 +10,16 @@ Requires-Python: >=3.8
  Description-Content-Type: text/markdown
  License-File: LICENSE
  Requires-Dist: numpy
- Requires-Dist: tqdm
- Requires-Dist: gpuctypes
  Requires-Dist: pyobjc-framework-Metal ; platform_system == "Darwin"
  Requires-Dist: pyobjc-framework-libdispatch ; platform_system == "Darwin"
  Provides-Extra: arm
  Requires-Dist: unicorn ; extra == 'arm'
+ Provides-Extra: docs
+ Requires-Dist: mkdocs-material ; extra == 'docs'
+ Requires-Dist: mkdocstrings[python] ; extra == 'docs'
+ Requires-Dist: markdown-callouts ; extra == 'docs'
+ Requires-Dist: markdown-exec[ansi] ; extra == 'docs'
+ Requires-Dist: black ; extra == 'docs'
  Provides-Extra: linting
  Requires-Dist: pylint ; extra == 'linting'
  Requires-Dist: mypy ; extra == 'linting'
@@ -30,10 +34,11 @@ Requires-Dist: torch ; extra == 'testing'
  Requires-Dist: pillow ; extra == 'testing'
  Requires-Dist: pytest ; extra == 'testing'
  Requires-Dist: pytest-xdist ; extra == 'testing'
- Requires-Dist: onnx ==1.15.0 ; extra == 'testing'
+ Requires-Dist: onnx ==1.16.0 ; extra == 'testing'
  Requires-Dist: onnx2torch ; extra == 'testing'
  Requires-Dist: opencv-python ; extra == 'testing'
  Requires-Dist: tabulate ; extra == 'testing'
+ Requires-Dist: tqdm ; extra == 'testing'
  Requires-Dist: safetensors ; extra == 'testing'
  Requires-Dist: transformers ; extra == 'testing'
  Requires-Dist: sentencepiece ; extra == 'testing'
@@ -41,18 +46,26 @@ Requires-Dist: tiktoken ; extra == 'testing'
  Requires-Dist: librosa ; extra == 'testing'
  Requires-Dist: networkx ; extra == 'testing'
  Requires-Dist: hypothesis ; extra == 'testing'
+ Requires-Dist: nibabel ; extra == 'testing'
+ Requires-Dist: bottle ; extra == 'testing'
+ Provides-Extra: testing_tf
+ Requires-Dist: tensorflow ==2.15.1 ; extra == 'testing_tf'
+ Requires-Dist: tensorflow-addons ; extra == 'testing_tf'
  Provides-Extra: triton
  Requires-Dist: triton-nightly >=2.1.0.dev20231014192330 ; extra == 'triton'
 
  <div align="center">
 
- [![logo](https://raw.githubusercontent.com/tinygrad/tinygrad/master/docs/logo.png)](https://tinygrad.org)
+ <picture>
+ <source media="(prefers-color-scheme: light)" srcset="/docs/logo_tiny_light.svg">
+ <img alt="tiny corp logo" src="/docs/logo_tiny_dark.svg" width="50%" height="50%">
+ </picture>
 
  tinygrad: For something between [PyTorch](https://github.com/pytorch/pytorch) and [karpathy/micrograd](https://github.com/karpathy/micrograd). Maintained by [tiny corp](https://tinygrad.org).
 
  <h3>
 
- [Homepage](https://github.com/tinygrad/tinygrad) | [Documentation](/docs) | [Examples](/examples) | [Showcase](/docs/showcase.md) | [Discord](https://discord.gg/ZjZadyC7PK)
+ [Homepage](https://github.com/tinygrad/tinygrad) | [Documentation](https://docs.tinygrad.org/) | [Discord](https://discord.gg/ZjZadyC7PK)
 
  </h3>
 
@@ -122,17 +135,15 @@ See [examples/beautiful_mnist.py](examples/beautiful_mnist.py) for the full vers
 
  tinygrad already supports numerous accelerators, including:
 
- - [x] [CPU](tinygrad/runtime/ops_cpu.py)
  - [x] [GPU (OpenCL)](tinygrad/runtime/ops_gpu.py)
- - [x] [C Code (Clang)](tinygrad/runtime/ops_clang.py)
+ - [x] [CLANG (C Code)](tinygrad/runtime/ops_clang.py)
  - [x] [LLVM](tinygrad/runtime/ops_llvm.py)
  - [x] [METAL](tinygrad/runtime/ops_metal.py)
  - [x] [CUDA](tinygrad/runtime/ops_cuda.py)
- - [x] [PyTorch](tinygrad/runtime/ops_torch.py)
- - [x] [HIP](tinygrad/runtime/ops_hip.py)
+ - [x] [AMD](tinygrad/runtime/ops_amd.py)
+ - [x] [NV](tinygrad/runtime/ops_nv.py)
 
  And it is easy to add more! Your accelerator of choice only needs to support a total of ~25 low level ops.
- More information can be found in the [documentation for adding new accelerators](/docs/adding_new_accelerators.md).
 
  ## Installation
 
@@ -154,7 +165,7 @@ python3 -m pip install git+https://github.com/tinygrad/tinygrad.git
 
  ## Documentation
 
- Documentation along with a quick start guide can be found in the [docs/](/docs) directory.
+ Documentation along with a quick start guide can be found on the [docs website](https://docs.tinygrad.org/) built from the [docs/](/docs) directory.
 
  ### Quick example comparing to PyTorch
 
@@ -193,13 +204,14 @@ We'll start with what will get your PR closed with a pointer to this section:
  - All docs and whitespace changes will be closed unless you are a well-known contributor. The people writing the docs should be those who know the codebase the absolute best. People who have not demonstrated that shouldn't be messing with docs. Whitespace changes are both useless *and* carry a risk of introducing bugs.
  - Anything you claim is a "speedup" must be benchmarked. In general, the goal is simplicity, so even if your PR makes things marginally faster, you have to consider the tradeoff with maintainablity and readablity.
  - In general, the code outside the core `tinygrad/` folder is not well tested, so unless the current code there is broken, you shouldn't be changing it.
+ - If your PR looks "complex", is a big diff, or adds lots of lines, it won't be reviewed or merged. Consider breaking it up into smaller PRs that are individually clear wins. A common pattern I see is prerequisite refactors before adding new functionality. If you can (cleanly) refactor to the point that the feature is a 3 line change, this is great, and something easy for us to review.
 
  Now, what we want:
 
  - Bug fixes (with a regression test) are great! This library isn't 1.0 yet, so if you stumble upon a bug, fix it, write a test, and submit a PR, this is valuable work.
  - Solving bounties! tinygrad [offers cash bounties](https://docs.google.com/spreadsheets/d/1WKHbT-7KOgjEawq5h5Ic1qUWzpfAzuD_J06N1JwOCGs/edit?usp=sharing) for certain improvements to the library. All new code should be high quality and well tested.
  - Features. However, if you are adding a feature, consider the line tradeoff. If it's 3 lines, there's less of a bar of usefulness it has to meet over something that's 30 or 300 lines. All features must have regression tests. In general with no other constraints, your feature's API should match torch or numpy.
- - Refactors that are clear wins. In general, if your refactor isn't a clear win it will be closed. But some refactors are amazing! Think about readability in a deep core sense. A whitespace change or moving a few functions around is useless, but if you realize that two 100 line functions can actually use the same 110 line function with arguments while also improving readability, this is a big win.
+ - Refactors that are clear wins. In general, if your refactor isn't a clear win it will be closed. But some refactors are amazing! Think about readability in a deep core sense. A whitespace change or moving a few functions around is useless, but if you realize that two 100 line functions can actually use the same 110 line function with arguments while also improving readability, this is a big win. Refactors should pass [process replay](#process-replay-tests).
  - Tests/fuzzers. If you can add tests that are non brittle, they are welcome. We have some fuzzers in here too, and there's a plethora of bugs that can be found with them and by improving them. Finding bugs, even writing broken tests (that should pass) with `@unittest.expectedFailure` is great. This is how we make progress.
  - Dead code removal from core `tinygrad/` folder. We don't care about the code in extra, but removing dead code from the core library is great. Less for new people to read and be confused by.
 
@@ -215,3 +227,9 @@ python3 -m pip install -e '.[testing]' # install extra deps for testing
  python3 test/test_ops.py # just the ops tests
  python3 -m pytest test/ # whole test suite
  ```
+
+ #### Process replay tests
+
+ [Process replay](https://github.com/tinygrad/tinygrad/blob/master/test/external/process_replay/process_replay.py) detects changes in the generated kernels of CI tests by comparing them against tinygrad master. If your PR is a refactor or speedup without any expected behavior change, it should include a green process replay pass to get merged.
+
+ You can enable process replay by adding [run_process_replay] to your PR title. [example](https://github.com/tinygrad/tinygrad/pull/4995). Note that you should keep your branch up-to-date with master.
tinygrad-0.9.1.dist-info/RECORD
@@ -0,0 +1,63 @@
+ tinygrad/__init__.py,sha256=jC-35zswLSXLuRRThG_o6yar6qQjLCqmeaFCj_XKN08,449
+ tinygrad/device.py,sha256=oy8vQAn0bLSVnLtu6MfusWF6Axv83xEwrJB3C1sWcPo,18287
+ tinygrad/dtype.py,sha256=LGHdreCNcM-FJTwT3uNrsGIMCSnBHi4ZeGf3K0QHXGU,6199
+ tinygrad/function.py,sha256=1PS3nZBicJmMAZ-IvE0SJAHWrGyRMkWM9-J_Hr5az0E,9473
+ tinygrad/helpers.py,sha256=SQfdz3AxiAT_k7GExSIXKBYqvCugxaF7ftQ9gr8qvFg,16814
+ tinygrad/lazy.py,sha256=mrmxsa6dAQhVAm2UMs7yM18Y-IeJyG4rqyW4-LL5apQ,13336
+ tinygrad/multi.py,sha256=EK9OTqmt2EaEgJIIebZUe944NaoJLTFbnUGxVYIk--k,11520
+ tinygrad/ops.py,sha256=ZOKpgKY_y3BGqu5NpL172me10Lw4SYVkkCUeD3HR7MI,8607
+ tinygrad/tensor.py,sha256=-TRbZ1ydsxgsHUlMDfx6AuEY4qtxUbYhYQ8hK7w8mt0,133764
+ tinygrad/codegen/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tinygrad/codegen/kernel.py,sha256=Bw53Hb0vfdyyz2JQhE7YV7DqfZh3gVsxtLFABkx8Xps,36961
+ tinygrad/codegen/linearizer.py,sha256=oxqSuzBwpbOtZL1vTxVR4qQqjdvvAUwayhVMltkkX1s,32964
+ tinygrad/codegen/uops.py,sha256=GzEwfeyATi9uydzHnyFeN_V47hrk7pkz3gHePjyp6jY,25454
+ tinygrad/engine/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tinygrad/engine/graph.py,sha256=2w16XKpAXSkEML8Hl3m1bbPnLDV-MBsaBJLBPx2QzA0,5264
+ tinygrad/engine/jit.py,sha256=xoiApaot299mcjKPbxwUlqu2SJ3Tbyfs_rm5zKuHNMA,11355
+ tinygrad/engine/realize.py,sha256=LsfZIHQ2Zz_79a4LKBU1pMUwbleM1x_oBTyn4D1qRc8,11214
+ tinygrad/engine/schedule.py,sha256=ltOshvCgUU5wHqTDVLhxPVo7DGq7VRA6ECRX6C_09oU,18984
+ tinygrad/engine/search.py,sha256=rLS1Aj_r42L0_uUJRyytqFTpJvXxPdx0LRD7BzvF2dQ,11359
+ tinygrad/nn/__init__.py,sha256=-CH1lNAoxJQjFeGFpp78RYLaD2sD4Q4biJA5uvWwm88,12942
+ tinygrad/nn/datasets.py,sha256=3tau-L-HiopP1leFDXZV-VamCR0DxOhAu6wVOkAVSZY,493
+ tinygrad/nn/optim.py,sha256=zf85kwumpS17fk1NBjUfe7tOUE7-XH37SL-LjgAfva8,6799
+ tinygrad/nn/state.py,sha256=cYxsy6IiMCl1wITZcM04GirsftFOutp-2X0PIow35aY,10111
+ tinygrad/renderer/__init__.py,sha256=OVEBlNy-eaREIn9jrY0tBhVsmMYi1eZK3wW_P_f26-M,2935
+ tinygrad/renderer/assembly.py,sha256=syUM3dMOrmIEkmVWzOS47E2sZXDUWbVXmkofl0RnnWw,18011
+ tinygrad/renderer/cstyle.py,sha256=O2MNxsxZzMCPY7JAC_5w5aUqosJSi4LGdH8PZoA7s2A,24387
+ tinygrad/renderer/llvmir.py,sha256=43iw-cfwukwffibuZQOm3jI3fPPcRT8rs4-JVglVIHU,10192
+ tinygrad/runtime/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tinygrad/runtime/ops_amd.py,sha256=h4ZwP64C82aNuHTUtjRykiEzsdhUwGKtvfNdsXOUXyI,31393
+ tinygrad/runtime/ops_clang.py,sha256=XWqwobReRdu-Tj-chbWEJFMx6AQfgdGCcpdWcLWUTOQ,1468
+ tinygrad/runtime/ops_cuda.py,sha256=LujFWnZvLEaPlebomy9c3C4WAbbaEswdeuJgCNDIiGg,10860
+ tinygrad/runtime/ops_disk.py,sha256=5AGIXyGb_x3zuPQ54M6p0xkROjtQo0MMeiqSUPnQ6MQ,6655
+ tinygrad/runtime/ops_gpu.py,sha256=z9MPEQawWbJGfLyN3KOgxXaENEbi3ULm8xMWcrln1c0,7532
+ tinygrad/runtime/ops_llvm.py,sha256=dODiyVSlPofHyDIZrD-V74h8W1d94VPnr_-A4gNbSO4,2229
+ tinygrad/runtime/ops_metal.py,sha256=cb4xIXp_4wyMjk7E2XuSwcCzh1JEwc21p6hJxB8LPgc,5802
+ tinygrad/runtime/ops_npy.py,sha256=HSDb47vcbPvDiG3EmSGPC0W-EPO9O5R0TqlQKnYxANU,404
+ tinygrad/runtime/ops_nv.py,sha256=ClqazkWrMOheiyRDl1ZY7azMPqFHstnEhlAEKNSkklg,37657
+ tinygrad/runtime/ops_python.py,sha256=2B1ty1fr4SpPq-bfEyA-gDxxvjI7ulS1EAqHwHFn-QE,11051
+ tinygrad/runtime/autogen/amd_gpu.py,sha256=kqu1ygAUxymdwFznsIrXnsJYDa0UDpVndTlTuj9XK_4,581516
+ tinygrad/runtime/autogen/comgr.py,sha256=mhQtuF_vGDrJZD3iyQ_38FrS_2jp3WtImNZPC4TBuAg,39793
+ tinygrad/runtime/autogen/cuda.py,sha256=GgRl4AfU54JG0G1XJj2dq2FbrUZ8XG_AnFrPAZJpSSg,250380
+ tinygrad/runtime/autogen/hip.py,sha256=1yUHDCwL3KkD15if2Q1Ud3GbJiR7DxsNorKZTCINw54,245532
+ tinygrad/runtime/autogen/hsa.py,sha256=7Hsrn17HmChyeFOSX_3Fnzl9c0COtq2Z2ExqGu5FNiU,277716
+ tinygrad/runtime/autogen/io_uring.py,sha256=cWYzokworiNV1x4I6eBmZpqV2nCt5B_rVEDx7pFP2W0,57664
+ tinygrad/runtime/autogen/kfd.py,sha256=dDmLFL6HL_QXRW2rZOCArY55PRuXuLN9563XCipV2jM,29935
+ tinygrad/runtime/autogen/nv_gpu.py,sha256=Hd8p-ay8dbdqFwvGyEtXLX3Ymjr3DP1XSRMAe0zKrq4,1686859
+ tinygrad/runtime/autogen/opencl.py,sha256=aW-luGFF5PXFEZ6SgrGANhA9qpkI-fZyEsmDfpez2Ss,82654
+ tinygrad/runtime/driver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tinygrad/runtime/driver/hip_comgr.py,sha256=rCS6DaUjkYflI1OvP5XGCOvcy67XCAgeSvQMUb1K9iY,3840
+ tinygrad/runtime/graph/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tinygrad/runtime/graph/clang.py,sha256=NjobP8AVzNC-CS25FjO5gl2-B8KGJb8K8kLzNXEycPg,2005
+ tinygrad/runtime/graph/cuda.py,sha256=ZXOhUhTzz3L2ut-E0tCTrqTEcDFExx_FaZzBm8HcQIk,5268
+ tinygrad/runtime/graph/hcq.py,sha256=iJcAw-95PXwjkNgQbwwhFmq1m9Zprq8wwPNFkil_SEQ,11877
+ tinygrad/runtime/graph/metal.py,sha256=bwB6uAsqjEbwv5ML5ziWduBtmTpseJboo6J9ssVa4v4,4579
+ tinygrad/shape/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tinygrad/shape/shapetracker.py,sha256=oJqr3q9HoDwNN8S2VV9H_IpmG_ovGY10CTSaWWZjUz4,6387
+ tinygrad/shape/symbolic.py,sha256=ORyH9bkb94u3ak0XvM7TvZj2n-_RlR9FcYlQqmIUc2M,16617
+ tinygrad/shape/view.py,sha256=H9tdtNWD1xKci_Qhf3EJsWKPD8x4M-e7MAxgoUDqG9k,17999
+ tinygrad-0.9.1.dist-info/LICENSE,sha256=ABRhUPEILzINYIukgazD-_rPipkUNUwslrb0RxnV6Xc,1058
+ tinygrad-0.9.1.dist-info/METADATA,sha256=YApgO_KsHfOTp9m10ytDj_VIVghIq5GP9PTIiIhre2c,10990
+ tinygrad-0.9.1.dist-info/WHEEL,sha256=mguMlWGMX-VHnMpKOjjQidIo1ssRlCFu4a4mBpz1s2M,91
+ tinygrad-0.9.1.dist-info/top_level.txt,sha256=vDABMCWBFQnx2kn9Azueu88FP-1klQdePoHikQhHymc,9
+ tinygrad-0.9.1.dist-info/RECORD,,
{tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: bdist_wheel (0.42.0)
+ Generator: setuptools (70.1.1)
  Root-Is-Purelib: true
  Tag: py3-none-any
 
tinygrad/features/image.py
@@ -1,93 +0,0 @@
- from typing import Tuple
- from tinygrad.helpers import prod, IMAGE, getenv, DEBUG
- from tinygrad.dtype import dtypes
-
- # *** image Tensor function replacements ***
-
- def image_dot(self, w):
-   # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
-   n1, n2 = len(self.shape), len(w.shape)
-   assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
-   assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})" # noqa: E501
-   bs, groups, cin, cout = prod(self.shape[0:-2]), prod(w.shape[0:-2]), w.shape[-2], w.shape[-1]
-   out_shape_t = self.shape[0:-2] + (cout,-1) if len(self.shape) > 1 else (cout, )
-
-   # NOTE: with NHWC we can remove the transposes
-   # bs x groups*cin x H x W
-   cx = self.transpose(self.ndim-1, self.ndim-2).reshape((bs//groups, groups*cin, -1, 1))
-   # groups*cout x cin x H, W
-   cw = w.transpose(w.ndim-1, w.ndim-2).reshape((groups*cout, cin, 1, 1))
-   return image_conv2d(cx, cw, groups=groups).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)
-
- def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, padding=0):
-   base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef
-
-   (bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
-   x, w = self, weight.reshape(groups, (rcout := cout//groups), cin, H, W)
-
-   # hack for non multiples of 4 on cin
-   if cin % 4 != 0 and not (cin == 1 and groups%4 == 0):
-     x = x.reshape(bs, groups, cin, iy, ix) # do this always?
-     added_input_channels = 4 - (cin % 4)
-     w = w.pad(tuple((0, added_input_channels) if i == 2 else None for i in range(w.ndim)))
-     x = x.pad(tuple((0, added_input_channels) if i == 2 else None for i in range(x.ndim)))
-     cin = cin + added_input_channels
-     x = x.reshape(bs, groups*cin, iy, ix)
-
-   # hack for non multiples of 4 on rcout
-   added_output_channels = 0
-   if rcout % 4 != 0 and not (rcout == 1 and groups%4 == 0):
-     added_output_channels = 4 - (rcout % 4)
-     rcout += added_output_channels
-     cout = groups * rcout
-     w = w.pad(tuple((0, added_output_channels) if i == 1 else None for i in range(w.ndim)))
-
-   # packed (note: flipping bs and iy would make the auto-padding work)
-   x = x.permute(0,2,3,1)
-   cin_last = iy == 1 and ix == 1
-   if cin == 1: w = w.reshape(cout//4,4,H,W).permute(0,2,3,1)
-   elif cin_last: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,1,3)
-   else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,3,1)
-
-   # contiguous creates the image, and early realize static weights (TODO: test for the static weight)
-   if IMAGE >= 2: x,w = x.cast(base_image_type((bs*iy, ix*groups*cin//4, 4))), w.cast(base_image_type((cout//4, H*W*cin, 4)))
-   x, w = x.contiguous(), w.contiguous()
-
-   # expand out
-   rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1
-   cout_expand = [groups//4 if cin == 1 else groups, 4 if cin == 1 else 1, rcout//4 if rcout >= 4 else 1, 4 if rcout >= 4 else 1]
-   x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo)
-   if cin_last: w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo)
-   else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4)
-
-   # padding
-   padding_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]])
-   x = x.slice((None, (-padding_[2], x.shape[1]+padding_[3]), (-padding_[0], x.shape[2]+padding_[1]), None, None, None))
-
-   # prepare input
-   x = x.permute(0,3,4,5,1,2)._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
-   x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, (oy := x.shape[4]), (ox := x.shape[5]), *cout_expand[0:2], 1, 1, rcin_hi, rcin_lo, H, W)
-
-   # prepare weights
-   w = w.permute(0,4,2,5,1,3).reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W))
-
-   # the conv!
-   ret = (x*w).cast(base_image_type((bs*oy, ox*cout//4, 4)) if IMAGE >= 2 else dtypes.float32).sum((-4, -3, -2, -1))
-
-   # undo hack for non multiples of 4 on C.rcout
-   if added_output_channels != 0:
-     ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :, :, :-added_output_channels]
-     cout = groups * (rcout - added_output_channels)
-
-   # NCHW output
-   ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2)
-   return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
-
- # *** images have weird indexing requirements ***
-
- from tinygrad.shape.symbolic import Node
- def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node) -> Tuple[Tuple[Node, Node], Node]:
-   idx, idy = (idxy // 4) % base_shape[1], (idxy // (4 * base_shape[1]))
-   # TODO: bring back the valid removal logic (correct!)
-   if DEBUG>=5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy, valid)
-   return (idx, idy), valid
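For reference, the index arithmetic in to_image_idx above maps a flat buffer index onto the (x, y) texel coordinates of an image-typed buffer whose dtype shape is (height, width, 4). A standalone worked example of that arithmetic (plain Python, not tinygrad API; the helper name is made up for illustration):

```python
from typing import Tuple

# Same arithmetic as to_image_idx above: 4 channels per texel, row-major texels,
# width = base_shape[1].
def to_image_xy(flat_idx: int, width: int) -> Tuple[int, int]:
  return (flat_idx // 4) % width, flat_idx // (4 * width)

# with width=8: indices 0..3 land in texel (0, 0), 4..7 in (1, 0), 32..35 in (0, 1)
assert to_image_xy(0, width=8) == (0, 0)
assert to_image_xy(5, width=8) == (1, 0)
assert to_image_xy(32, width=8) == (0, 1)
```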
tinygrad/features/multi.py
@@ -1,103 +0,0 @@
- from __future__ import annotations
- from typing import Optional, Union, Any, Tuple, List
- import functools
- from tinygrad.helpers import all_same, dedup
- from tinygrad.dtype import DType
- from tinygrad.ops import BinaryOps, LoadOps, UnaryOps, TernaryOps, ReduceOps
- from tinygrad.lazy import LazyBuffer, create_schedule
- from tinygrad.shape.shapetracker import ShapeTracker, sint
-
- def all_reduce(lbs):
-   # TODO: replace this with ring reduce
-   return [functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
-
- def to_sharded(lbs:List[LazyBuffer], axis:int) -> List[LazyBuffer]:
-   assert lbs[0].shape[axis] % len(lbs) == 0, f"{lbs[0].shape=} {axis=} {len(lbs)=}"
-   sz = lbs[0].shape[axis] // len(lbs)
-   return [lb.shrink(tuple((0,s) if a != axis else (sz*i,sz*(i+1)) for a,s in enumerate(lb.shape))) for i,lb in enumerate(lbs)]
-
- class MultiLazyBuffer:
-   def __init__(self, lbs:List[LazyBuffer], axis:Optional[int]):
-     assert all(isinstance(x, LazyBuffer) for x in lbs) and len(lbs) >= 2, "all lbs must be LazyBuffers, and we need at least two of them"
-     assert all_same([(x.shape, x.dtype, x.st) for x in lbs]), "all multilazybuffer needs same shape, dtype, and st"
-     self.lbs, self.axis, self.dtype, self.device = lbs, axis, lbs[0].dtype, tuple(x.device for x in lbs)
-     self.shape = tuple(s*len(self.lbs) if a == self.axis else s for a,s in enumerate(lbs[0].shape))
-
-   def __repr__(self):
-     return f"<MLB{chr(10)}{chr(10).join([f'{x.device} {x.st}' for x in self.lbs])}>"
-
-   @staticmethod
-   def from_sharded(lb:LazyBuffer, devices:Tuple[str, ...], axis:Optional[int]=None):
-     lbs = [lb.contiguous() if lb.base != lb else lb] * len(devices)
-     return MultiLazyBuffer([lb.copy_to_device(d).contiguous() for lb,d in zip(to_sharded(lbs, axis) if axis is not None else lbs, devices)], axis)
-
-   def copy_to_device(self, device:str) -> LazyBuffer:
-     if self.axis is None: return self.lbs[0].copy_to_device(device)
-     sz = self.lbs[0].shape[self.axis]
-     llbs = []
-     for i,lb in enumerate([lb.copy_to_device(device) for lb in self.lbs]):
-       pad_arg = tuple((0,0) if a != self.axis else (sz*i,(s*len(self.lbs))-sz*(i+1)) for a,s in enumerate(lb.shape))
-       llbs.append(lb.pad(pad_arg))
-     return functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), llbs)
-
-   # TODO: fix this
-   def is_unrealized_contiguous_const(self): return False
-
-   # passthroughs
-   def schedule(self, seen=None): return create_schedule(self.lbs, seen)
-   def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis)
-   def const(self, val:Union[float, int]) -> MultiLazyBuffer: return MultiLazyBuffer([x.const(val) for x in self.lbs], self.axis)
-   def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis)
-
-   # elementwise is simple
-   def e(self, op:Union[LoadOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:MultiLazyBuffer, arg:Optional[Any]=None) -> MultiLazyBuffer:
-     msrcs = (self,)+in_srcs
-     assert all(isinstance(x, MultiLazyBuffer) for x in msrcs), f"all buffers must be MultiLazyBuffer {msrcs}"
-     assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"
-
-     # NOTE: they all have to share an axis, we always choose [-1]
-     axis = axes[-1] if len(axes := dedup([x.axis for x in msrcs if x.axis is not None])) else None
-     srcs = []
-     for mlb in msrcs:
-       if mlb.axis == axis: srcs.append(mlb.lbs)
-       elif mlb.axis is None and axis is not None: srcs.append(to_sharded(mlb.lbs, axis))
-       else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis))
-     return MultiLazyBuffer([lsrcs[0].e(op, *lsrcs[1:], arg=arg) for lsrcs in zip(*srcs)], axis)
-
-   def _shape_to_single_shard(self, shape): return tuple(s//len(self.lbs) if a == self.axis else s for a,s in enumerate(shape))
-
-   def r(self, op:ReduceOps, new_shape:Tuple[sint, ...]) -> MultiLazyBuffer:
-     if self.axis is not None and new_shape[self.axis] == 1:
-       # all-reduce on sharded axes
-       return MultiLazyBuffer(all_reduce([x.r(op, new_shape) for x in self.lbs]), None)
-     # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
-     return MultiLazyBuffer([x.r(op, self._shape_to_single_shard(new_shape)) for x in self.lbs], self.axis)
-
-   # *** movement ops ***
-
-   def reshape(self, arg:Tuple[sint, ...]):
-     if self.axis is None: return MultiLazyBuffer([x.reshape(arg) for x in self.lbs], None)
-     # TODO: this can be wrong
-     st = ShapeTracker.from_shape(self.shape)
-     rs = st.real_strides()[self.axis]
-     new_axis = st.reshape(arg).real_strides().index(rs)
-     narg = tuple(s//len(self.lbs) if a == new_axis else s for a,s in enumerate(arg))
-     return MultiLazyBuffer([x.reshape(narg) for x in self.lbs], new_axis)
-
-   def pad(self, arg:Tuple[Tuple[sint, sint], ...]):
-     assert self.axis is None or arg[self.axis] == (0,0), "padding not supported on sharded axis"
-     return MultiLazyBuffer([x.pad(arg) for x in self.lbs], self.axis)
-   def expand(self, arg:Tuple[sint, ...]):
-     # NOTE: this assert isn't needed, sharded axis can have dim 1
-     assert self.axis is None or arg[self.axis] == self.lbs[0].shape[self.axis] * len(self.lbs), "expand not supported on sharded axis"
-     return MultiLazyBuffer([x.expand(self._shape_to_single_shard(arg)) for x in self.lbs], self.axis)
-   def permute(self, arg:Tuple[int, ...]):
-     # all permutes supported!
-     return MultiLazyBuffer([x.permute(arg) for x in self.lbs], arg.index(self.axis) if self.axis is not None else None)
-   def shrink(self, arg:Tuple[Tuple[sint, sint], ...]):
-     assert self.axis is None or arg[self.axis] == (0, self.lbs[0].shape[self.axis] * len(self.lbs)), "shrinking not supported on sharded axis"
-     narg = tuple((s1//len(self.lbs), s2//len(self.lbs)) if a == self.axis else (s1,s2) for a,(s1,s2) in enumerate(arg))
-     return MultiLazyBuffer([x.shrink(narg) for x in self.lbs], self.axis)
-   def stride(self, arg:Tuple[int, ...]):
-     assert self.axis is None or arg[self.axis] == 1, "flipping not supported on sharded axis"
-     return MultiLazyBuffer([x.stride(arg) for x in self.lbs], self.axis)
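Per the files-changed list and the new RECORD, this module survives as tinygrad/multi.py in 0.9.1. The core trick in copy_to_device above (gather a sharded buffer by zero-padding each shard back to full length on the sharded axis and summing) can be illustrated standalone; the sketch below is plain Python, not tinygrad API, and the helper name is invented for the example:

```python
from typing import List

# Pure-Python sketch of MultiLazyBuffer.copy_to_device's gather: each device holds a
# contiguous slice along the sharded axis; padding every slice with zeros back to the
# full length and adding them elementwise reconstructs the original buffer.
def gather_1d(shards: List[List[float]]) -> List[float]:
  sz, n = len(shards[0]), len(shards)
  out = [0.0] * (sz * n)
  for i, shard in enumerate(shards):   # shard i covers [sz*i, sz*(i+1)), zeros elsewhere
    for j, v in enumerate(shard):
      out[sz * i + j] += v             # the BinaryOps.ADD reduce over padded shards
  return out

assert gather_1d([[1.0, 2.0], [3.0, 4.0]]) == [1.0, 2.0, 3.0, 4.0]
```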
tinygrad/features/search.py
@@ -1,160 +0,0 @@
- from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
- import itertools, random, math, time, multiprocessing, traceback, signal
- from tinygrad.device import Device, Compiled, Buffer, CompiledASTRunner
- from tinygrad.ops import MemBuffer, LazyOp
- from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
- from tinygrad.dtype import ImageDType
- from tinygrad.codegen.linearizer import Linearizer
- from collections import defaultdict
- from tinygrad.tensor import Tensor
- from tinygrad.shape.symbolic import sym_infer
-
- from tinygrad.codegen.kernel import Opt, OptOps
- actions = [Opt(op=OptOps.UPCAST, axis=axis, amt=amt) for amt in [0,2,3,4,7] for axis in range(6)]
- actions += [Opt(op=OptOps.UNROLL, axis=axis, amt=amt) for amt in [0,4] for axis in range(4)]
- actions += [Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,13,16,29] for axis in range(5)]
- actions += [Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,29,32,256] for axis in range(3)]
- actions += [Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32] for axis in range(7)]
- actions += [Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.UPCASTMID, axis=1, amt=4),
-             Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.GROUP, axis=1, amt=8),]
- if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]
-
- def _get_test_global_size(global_size, max_global_size, var_vals):
-   test_global_size, factor = [sym_infer(sz, var_vals) for sz in global_size], 1
-   while prod(test_global_size) > max_global_size:
-     for j in range(len(global_size)-1,-1,-1):
-       if test_global_size[j] > 16:
-         test_global_size[j] //= 2
-         factor *= 2
-         break
-   return test_global_size, factor
-
- def _time_program(ast:LazyOp, rdev:Compiled, lib:bytes, global_size, local_size, var_vals, rawbufs, early_stop=None, max_global_size=65536, clear_l2=False, cnt=3, name="test"): # noqa: E501
-   factor = 1
-   if global_size is not None and max_global_size is not None:
-     global_size, factor = _get_test_global_size(global_size, max_global_size, var_vals)
-   car = CompiledASTRunner(ast, name, "", lib, global_size, local_size).build(rdev.runtime)
-   tms = []
-   for _ in range(cnt):
-     if clear_l2:
-       with Context(DEBUG=0): Tensor.rand(1024,1024).realize()
-     tms.append(car(rawbufs, var_vals, wait=True, do_update_stats=False)*factor)
-     if early_stop is not None and early_stop < tms[-1]: break
-   return tms
-
- def _compile_linearizer(rdev:Compiled, lin:Linearizer, name:Optional[str]=None) -> Tuple[bytes, Optional[List[int]], Optional[List[int]]]:
-   lin.linearize()
-   src = rdev.renderer(name if name is not None else to_function_name(lin.name), lin.uops) # NOTE: these all have the same name for deduping
-   return rdev.compiler(src), lin.global_size, lin.local_size
-
- def _try_compile_linearized_w_idx(x):
-   try: return (x[0], _compile_linearizer(cast(Compiled, Device[Device.DEFAULT]), x[1], "test"))
-   except Exception:
-     if DEBUG >= 4: traceback.print_exc()
-     return (x[0], None)
-
- # workers should ignore ctrl c
- def _init_worker(): signal.signal(signal.SIGINT, signal.SIG_IGN)
-
- # *** external API ***
-
- # get (scrap) buffers for timing the linearizer
- def bufs_from_lin(lin:Linearizer) -> List[Buffer]:
-   bufsts:DefaultDict[int, List[MemBuffer]] = defaultdict(list)
-   for x in lin.membufs: bufsts[x.idx].append(x)
-   rawbufs:List[Optional[Buffer]] = [None]*len(bufsts)
-   for k,lx in bufsts.items():
-     buf_size = prod(lx[0].dtype.shape) if isinstance(lx[0].dtype, ImageDType) else max(y.st.real_size() for y in lx)
-     rawbufs[k] = Buffer(Device.DEFAULT, buf_size, lx[0].dtype)
-   assert all(r is not None for r in rawbufs)
-   return cast(List[Buffer], rawbufs)
-
- # get dictionary of all possible actions
- def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Linearizer]:
-   acted_lins = {0:lin} if include_0 else {}
-   for i,a in enumerate(actions):
-     if a.axis is not None and a.axis >= lin.shape_len: continue
-     if a.axis is not None and lin.full_shape[a.axis] == a.amt and Opt(a.op, a.axis, 0) in actions: continue
-     lin2 = lin.copy()
-     try:
-       lin2.apply_opt(a)
-       up, lcl = 1, 1
-       for s,c in zip(lin2.full_shape, lin2.colors()):
-         if c in {"magenta", "yellow"}: up *= s
-         if c in {"cyan", "green", "white"}: lcl *= s
-       if up > 256 or lcl > 256: continue
-       acted_lins[i+1] = lin2
-     except Exception:
-       pass
-   return acted_lins
-
- def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linearizer:
-   key = {"ast": str(lin.ast), "amt": amt, "allow_test_size": allow_test_size, "device": Device.DEFAULT}
-   if (val:=diskcache_get("beam_search", key)) is not None and not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1:
-     ret = lin.copy()
-     for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
-     return ret
-
-   beam: List[Tuple[Linearizer, float]] = []
-   seen_libs = set()
-
-   default_parallel = 1 if Device.DEFAULT in {"CUDA", "HIP"} else 0
-   pool = multiprocessing.Pool(multiprocessing.cpu_count(), _init_worker) if getenv("PARALLEL", default_parallel) else None
-
-   try:
-     var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()}
-     exiting, st = False, time.perf_counter()
-     dev = Device[Device.DEFAULT]
-     assert isinstance(dev, Compiled)
-     while not exiting:
-       acted_lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam]) if len(beam) else [lin]
-       timed_lins: List[Tuple[Linearizer, float]] = []
-       for i,proc in (pool.imap_unordered(_try_compile_linearized_w_idx, enumerate(acted_lins)) if pool is not None else map(_try_compile_linearized_w_idx, enumerate(acted_lins))): # noqa: E501
-         if proc is None: continue
-         lib, global_size, local_size = proc
-         if lib in seen_libs: continue
-         seen_libs.add(lib)
-         tms = _time_program(lin.ast, dev, lib, global_size, local_size, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0)
-         timed_lins.append((acted_lins[i], min(tms)))
-         if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501
-
-       # done
-       opts = sorted(timed_lins, key=lambda x: x[1])
-       exiting = len(opts) == 0 or (len(beam) > 0 and beam[0][1] <= opts[0][1])
-       if not exiting: beam = opts[:amt]
-       assert len(beam) > 0, "no BEAM items succeeded?!?"
-       if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(f"{beam[0][1]*1e6:12.2f} us", "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape()) # noqa: E501
-     if pool is not None: pool.close() # the pool is closed
-   except KeyboardInterrupt as e:
-     if pool is not None: pool.terminate()
-     raise e
-
-   if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
-   if DEBUG >= 3: print(beam[0][0].applied_opts)
-   return beam[0][0]
-
- def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[Buffer]) -> List[int]:
-   test_rawbuffers = [Buffer(rawbufs[0].device, rawbufs[0].size, rawbufs[0].dtype), *rawbufs[1:]] if rawbufs[0] in rawbufs[1:] else rawbufs
-   MAX_WORKGROUP = 1024
-   local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in global_size]
-   local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice
-   def try_exec(local_size):
-     try: return clprg(*[x._buf for x in test_rawbuffers], global_size=[g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)], local_size=local_size, wait=True) # noqa: E501
-     except Exception: return float('inf')
-   ret = min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])
-   assert not math.isinf(ret[0]), "all optimize_local_size exec failed"
-   return ret[1]
-
- def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501
-   key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size, "clear_l2": clear_l2, "device": Device.DEFAULT} # noqa: E501
-   if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
-
-   dev = Device[Device.DEFAULT]
-   assert isinstance(dev, Compiled)
-
-   var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()}
-   lib, global_size, local_size = _compile_linearizer(dev, lin)
-   tms = _time_program(lin.ast, dev, lib, global_size, local_size, var_vals, rawbufs, max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name)) # noqa: E501
-
-   if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms)
-   return min(tms)
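Per the files-changed list and the new RECORD, this BEAM search machinery moves to tinygrad/engine/search.py in 0.9.1. Structurally, beam_search above is a standard beam search over kernel-optimization actions; a condensed, generic sketch of that loop (plain Python, not the tinygrad API):

```python
from typing import Callable, List, Tuple, TypeVar

T = TypeVar("T")

# Generic sketch of the loop in beam_search above: keep the `amt` fastest candidates,
# expand each with every applicable action, re-time, and stop once the best time
# stops improving.
def beam(start: T, expand: Callable[[T], List[T]], time_fn: Callable[[T], float], amt: int) -> T:
  frontier: List[Tuple[T, float]] = [(start, time_fn(start))]
  while True:
    candidates = sorted(((c, time_fn(c)) for node, _ in frontier for c in expand(node)), key=lambda x: x[1])
    if not candidates or frontier[0][1] <= candidates[0][1]: return frontier[0][0]  # no improvement
    frontier = candidates[:amt]
```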