tinygrad 0.8.0-py3-none-any.whl → 0.9.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. tinygrad/__init__.py +6 -6
  2. tinygrad/codegen/kernel.py +230 -190
  3. tinygrad/codegen/linearizer.py +278 -384
  4. tinygrad/codegen/uops.py +415 -0
  5. tinygrad/device.py +132 -275
  6. tinygrad/dtype.py +53 -37
  7. tinygrad/engine/__init__.py +0 -0
  8. tinygrad/engine/graph.py +100 -0
  9. tinygrad/engine/jit.py +195 -0
  10. tinygrad/engine/realize.py +191 -0
  11. tinygrad/engine/schedule.py +362 -0
  12. tinygrad/engine/search.py +196 -0
  13. tinygrad/{mlops.py → function.py} +28 -14
  14. tinygrad/helpers.py +72 -43
  15. tinygrad/lazy.py +141 -240
  16. tinygrad/multi.py +169 -0
  17. tinygrad/nn/__init__.py +179 -8
  18. tinygrad/nn/datasets.py +7 -0
  19. tinygrad/nn/optim.py +106 -28
  20. tinygrad/nn/state.py +86 -17
  21. tinygrad/ops.py +70 -44
  22. tinygrad/renderer/__init__.py +61 -0
  23. tinygrad/renderer/assembly.py +276 -0
  24. tinygrad/renderer/cstyle.py +299 -206
  25. tinygrad/renderer/llvmir.py +118 -123
  26. tinygrad/runtime/autogen/amd_gpu.py +1900 -0
  27. tinygrad/runtime/autogen/comgr.py +865 -0
  28. tinygrad/runtime/autogen/cuda.py +5923 -0
  29. tinygrad/runtime/autogen/hip.py +5909 -0
  30. tinygrad/runtime/autogen/hsa.py +5761 -0
  31. tinygrad/runtime/autogen/kfd.py +812 -0
  32. tinygrad/runtime/autogen/nv_gpu.py +33328 -0
  33. tinygrad/runtime/autogen/opencl.py +1795 -0
  34. tinygrad/runtime/driver/hip_comgr.py +47 -0
  35. tinygrad/runtime/driver/hsa.py +143 -0
  36. tinygrad/runtime/graph/clang.py +38 -0
  37. tinygrad/runtime/graph/cuda.py +59 -54
  38. tinygrad/runtime/graph/hcq.py +143 -0
  39. tinygrad/runtime/graph/hsa.py +171 -0
  40. tinygrad/runtime/graph/metal.py +37 -41
  41. tinygrad/runtime/ops_amd.py +564 -0
  42. tinygrad/runtime/ops_clang.py +16 -14
  43. tinygrad/runtime/ops_cuda.py +130 -38
  44. tinygrad/runtime/ops_disk.py +45 -42
  45. tinygrad/runtime/ops_gpu.py +52 -50
  46. tinygrad/runtime/ops_hsa.py +278 -0
  47. tinygrad/runtime/ops_llvm.py +36 -56
  48. tinygrad/runtime/ops_metal.py +42 -24
  49. tinygrad/runtime/ops_npy.py +9 -0
  50. tinygrad/runtime/ops_nv.py +630 -0
  51. tinygrad/runtime/ops_python.py +204 -0
  52. tinygrad/shape/shapetracker.py +41 -105
  53. tinygrad/shape/symbolic.py +98 -95
  54. tinygrad/shape/view.py +137 -35
  55. tinygrad/tensor.py +2367 -442
  56. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
  57. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/METADATA +19 -9
  58. tinygrad-0.9.0.dist-info/RECORD +60 -0
  59. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
  60. tinygrad/features/image.py +0 -93
  61. tinygrad/features/multi.py +0 -103
  62. tinygrad/features/search.py +0 -160
  63. tinygrad/graph.py +0 -106
  64. tinygrad/jit.py +0 -152
  65. tinygrad/realize.py +0 -50
  66. tinygrad/runtime/graph/hip.py +0 -24
  67. tinygrad/runtime/ops_cpu.py +0 -45
  68. tinygrad/runtime/ops_hip.py +0 -97
  69. tinygrad/runtime/ops_torch.py +0 -49
  70. tinygrad-0.8.0.dist-info/RECORD +0 -41
  71. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
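The restructuring visible in the file list above moves most non-core modules: `tinygrad/mlops.py` becomes `tinygrad/function.py`, the old top-level `jit.py`, `realize.py`, and `graph.py` plus the `tinygrad/features/` package are replaced by the new `tinygrad/engine/` package, `features/multi.py` is promoted to `tinygrad/multi.py`, and the removed `gpuctypes` dependency is replaced by generated bindings under `tinygrad/runtime/autogen/`. A rough sketch of what that means for imports, derived only from the file paths in this diff (the exact re-exports in `tinygrad/__init__.py` are not shown here, so treat the 0.9.0 paths as assumptions):

```python
# 0.8.0-era module paths (see the deleted/renamed files above):
#   from tinygrad.jit import TinyJit
#   from tinygrad.features.multi import MultiLazyBuffer
#   from tinygrad.features.search import beam_search
#   import tinygrad.mlops

# 0.9.0-era paths implied by the new layout (assumed, not documented in this diff):
from tinygrad.engine.jit import TinyJit          # tinygrad/jit.py -> tinygrad/engine/jit.py
from tinygrad.engine.search import beam_search   # tinygrad/features/search.py -> tinygrad/engine/search.py
from tinygrad.multi import MultiLazyBuffer       # tinygrad/features/multi.py -> tinygrad/multi.py
import tinygrad.function                         # tinygrad/mlops.py -> tinygrad/function.py
```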
{tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE RENAMED
@@ -1,4 +1,4 @@
- Copyright (c) 2023 George Hotz
+ Copyright (c) 2024, the tiny corp
 
  Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
 
{tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: tinygrad
- Version: 0.8.0
+ Version: 0.9.0
  Summary: You like pytorch? You like micrograd? You love tinygrad! <3
  Author: George Hotz
  License: MIT
@@ -11,11 +11,16 @@ Description-Content-Type: text/markdown
  License-File: LICENSE
  Requires-Dist: numpy
  Requires-Dist: tqdm
- Requires-Dist: gpuctypes
  Requires-Dist: pyobjc-framework-Metal ; platform_system == "Darwin"
  Requires-Dist: pyobjc-framework-libdispatch ; platform_system == "Darwin"
  Provides-Extra: arm
  Requires-Dist: unicorn ; extra == 'arm'
+ Provides-Extra: docs
+ Requires-Dist: mkdocs-material ; extra == 'docs'
+ Requires-Dist: mkdocstrings[python] ; extra == 'docs'
+ Requires-Dist: markdown-callouts ; extra == 'docs'
+ Requires-Dist: markdown-exec[ansi] ; extra == 'docs'
+ Requires-Dist: black ; extra == 'docs'
  Provides-Extra: linting
  Requires-Dist: pylint ; extra == 'linting'
  Requires-Dist: mypy ; extra == 'linting'
@@ -30,7 +35,7 @@ Requires-Dist: torch ; extra == 'testing'
  Requires-Dist: pillow ; extra == 'testing'
  Requires-Dist: pytest ; extra == 'testing'
  Requires-Dist: pytest-xdist ; extra == 'testing'
- Requires-Dist: onnx ==1.15.0 ; extra == 'testing'
+ Requires-Dist: onnx ==1.16.0 ; extra == 'testing'
  Requires-Dist: onnx2torch ; extra == 'testing'
  Requires-Dist: opencv-python ; extra == 'testing'
  Requires-Dist: tabulate ; extra == 'testing'
@@ -41,12 +46,19 @@ Requires-Dist: tiktoken ; extra == 'testing'
  Requires-Dist: librosa ; extra == 'testing'
  Requires-Dist: networkx ; extra == 'testing'
  Requires-Dist: hypothesis ; extra == 'testing'
+ Requires-Dist: nibabel ; extra == 'testing'
+ Provides-Extra: testing_tf
+ Requires-Dist: tensorflow ==2.15.1 ; extra == 'testing_tf'
+ Requires-Dist: tensorflow-addons ; extra == 'testing_tf'
  Provides-Extra: triton
  Requires-Dist: triton-nightly >=2.1.0.dev20231014192330 ; extra == 'triton'
 
  <div align="center">
 
- [![logo](https://raw.githubusercontent.com/tinygrad/tinygrad/master/docs/logo.png)](https://tinygrad.org)
+ <picture>
+   <source media="(prefers-color-scheme: light)" srcset="/docs/logo_tiny_light.svg">
+   <img alt="tiny corp logo" src="/docs/logo_tiny_dark.svg" width="50%" height="50%">
+ </picture>
 
  tinygrad: For something between [PyTorch](https://github.com/pytorch/pytorch) and [karpathy/micrograd](https://github.com/karpathy/micrograd). Maintained by [tiny corp](https://tinygrad.org).
 
@@ -122,17 +134,14 @@ See [examples/beautiful_mnist.py](examples/beautiful_mnist.py) for the full vers
  tinygrad already supports numerous accelerators, including:
 
- - [x] [CPU](tinygrad/runtime/ops_cpu.py)
  - [x] [GPU (OpenCL)](tinygrad/runtime/ops_gpu.py)
- - [x] [C Code (Clang)](tinygrad/runtime/ops_clang.py)
+ - [x] [CLANG (C Code)](tinygrad/runtime/ops_clang.py)
  - [x] [LLVM](tinygrad/runtime/ops_llvm.py)
  - [x] [METAL](tinygrad/runtime/ops_metal.py)
  - [x] [CUDA](tinygrad/runtime/ops_cuda.py)
- - [x] [PyTorch](tinygrad/runtime/ops_torch.py)
- - [x] [HIP](tinygrad/runtime/ops_hip.py)
+ - [x] [HSA](tinygrad/runtime/ops_hsa.py)
 
  And it is easy to add more! Your accelerator of choice only needs to support a total of ~25 low level ops.
- More information can be found in the [documentation for adding new accelerators](/docs/adding_new_accelerators.md).
 
  ## Installation
 
@@ -193,6 +202,7 @@ We'll start with what will get your PR closed with a pointer to this section:
  - All docs and whitespace changes will be closed unless you are a well-known contributor. The people writing the docs should be those who know the codebase the absolute best. People who have not demonstrated that shouldn't be messing with docs. Whitespace changes are both useless *and* carry a risk of introducing bugs.
  - Anything you claim is a "speedup" must be benchmarked. In general, the goal is simplicity, so even if your PR makes things marginally faster, you have to consider the tradeoff with maintainablity and readablity.
  - In general, the code outside the core `tinygrad/` folder is not well tested, so unless the current code there is broken, you shouldn't be changing it.
+ - If your PR looks "complex", is a big diff, or adds lots of lines, it won't be reviewed or merged. Consider breaking it up into smaller PRs that are individually clear wins. A common pattern I see is prerequisite refactors before adding new functionality. If you can (cleanly) refactor to the point that the feature is a 3 line change, this is great, and something easy for us to review.
 
  Now, what we want:
 
tinygrad-0.9.0.dist-info/RECORD ADDED
@@ -0,0 +1,60 @@
+ tinygrad/__init__.py,sha256=jC-35zswLSXLuRRThG_o6yar6qQjLCqmeaFCj_XKN08,449
+ tinygrad/device.py,sha256=zXcrFjBsiV1rW0aXupszDjD98TWLHin7u8pBd5fdJqo,10446
+ tinygrad/dtype.py,sha256=xg2BlFIPcQw0onHW_0ktGXjved9SXgQcLNrqe6gCXto,6221
+ tinygrad/function.py,sha256=0xkWst2tRsOeN6YcQS65MfVfWwKQYFkAacgkTys0VdQ,9616
+ tinygrad/helpers.py,sha256=XI8MIeBE35wQ4q0NEsUCvkj3QdY0adI80SfCbOySOVI,12773
+ tinygrad/lazy.py,sha256=xqaEqXaIpt_77SP_2U6Pyfw8YeGd_0PzNDXJOOnRJ24,13379
+ tinygrad/multi.py,sha256=gyGXYVviaPfzAkoJjLFUiusVd3no6HRJunOOxD0DaaY,11362
+ tinygrad/ops.py,sha256=aNk1jLuJl--Z_u8DE8Du1iN1AFhMH-6Le5mnvRHLDvI,7124
+ tinygrad/tensor.py,sha256=nznRGHH7-64bMpFeu8gvdSDfJ5CEEoezIcHSxPgeJ7k,129223
+ tinygrad/codegen/kernel.py,sha256=RRRmOX3iOOgu5ISABW_UTVh5vfGdLFfK1UOtKtpghuY,38169
+ tinygrad/codegen/linearizer.py,sha256=jxwEcxxcpWOvYlIgXmlGkgNYZ4sDykiY_5Ve3a8tpYg,27622
+ tinygrad/codegen/uops.py,sha256=yKS-3w9teuS_3BLnHAN4vWtSRvZHmsx194YRBXMOhFI,21872
+ tinygrad/engine/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tinygrad/engine/graph.py,sha256=eEbb17qbJ0A-2VjN4l7SCbA4yI7jh6YN6EybPShmcbg,5221
+ tinygrad/engine/jit.py,sha256=TrdZQEXnF-SowCAyy69iEp85NPFFtOVaIz_2i1cwCvQ,11049
+ tinygrad/engine/realize.py,sha256=H2CgiLWRqTQSr9udJb1iLX7DFdLcPLwxSJZ2bUfHXDs,11077
+ tinygrad/engine/schedule.py,sha256=I3OxiNwveWbVuhhFWdR1B5G8goEyqXOHWe7UQQ3Ogz8,18487
+ tinygrad/engine/search.py,sha256=M11qHlufIffBS9b7mjk8gniyoGRaengwXnMfUQTHmyw,11289
+ tinygrad/nn/__init__.py,sha256=DoHrq9pUFs1vm9FUA5eet5_tvhlZJlC_uC4kBqo98kI,12932
+ tinygrad/nn/datasets.py,sha256=Mvf_0eCEEqtB9-8iiyDFhm-B7rTyWXwmSY_3xjeHleo,458
+ tinygrad/nn/optim.py,sha256=zf85kwumpS17fk1NBjUfe7tOUE7-XH37SL-LjgAfva8,6799
+ tinygrad/nn/state.py,sha256=nGR05s3kuDNp9lliCIr4-6Ek7Korha7jCAWie5S2rB4,10138
+ tinygrad/renderer/__init__.py,sha256=-LjQ9tC2rI8fveaS_xn24X_knXKILFj-iZFcRTk8fNM,2672
+ tinygrad/renderer/assembly.py,sha256=MD-SSC7-Nqwt3zrwe0aDXVX08W9Ox6Vj_byPS1k1bAQ,17923
+ tinygrad/renderer/cstyle.py,sha256=tFWWW-egorLFEDwX6fA9-rYxvNLc67LjxlZ6JzrWCF0,24043
+ tinygrad/renderer/llvmir.py,sha256=BZViWXj2G6JtEcOgc-CtnIj-d9xP0ZjgNXdTKQT_PJ8,10315
+ tinygrad/runtime/ops_amd.py,sha256=3jOrFqxk8JkPX043tEUFLvyKSgX7Fls785g_gOkdzVM,31811
+ tinygrad/runtime/ops_clang.py,sha256=XWqwobReRdu-Tj-chbWEJFMx6AQfgdGCcpdWcLWUTOQ,1468
+ tinygrad/runtime/ops_cuda.py,sha256=cgeoVpY9bOGU22Eh78XR5YOYY2cgsJt4Vnxl6u8N6co,10840
+ tinygrad/runtime/ops_disk.py,sha256=75-iihZxkhNvA5O3VaW61LOXwmlSX4XwegpnV1C4D5A,2738
+ tinygrad/runtime/ops_gpu.py,sha256=FB3Fp-VVEDGEt_6CfJxsM_TWzhp5giXCP1TSSRMXE80,7532
+ tinygrad/runtime/ops_hsa.py,sha256=YNQLqZjJ9twTJRKS41l2oIrncOAu3wdOdBegs9zYlgo,16188
+ tinygrad/runtime/ops_llvm.py,sha256=dODiyVSlPofHyDIZrD-V74h8W1d94VPnr_-A4gNbSO4,2229
+ tinygrad/runtime/ops_metal.py,sha256=fGSNpwmYIHaif9a5SiwyMX2bub-r5hTNpnrqlaPMeUc,5815
+ tinygrad/runtime/ops_npy.py,sha256=qaAi0AEo6nt7iZ-eWqM8z2aQfNJgZUpmBCEDmrIzWL0,369
+ tinygrad/runtime/ops_nv.py,sha256=PCMAHMrW4J7esgnkpwq3bB91Q3h6hBATr8JuykR9vcA,37633
+ tinygrad/runtime/ops_python.py,sha256=mmsDj1hJ3BtreAq5dfCuuUGbgoIhCKlVwNqMDmXBISs,10871
+ tinygrad/runtime/autogen/amd_gpu.py,sha256=1NDH0ualiZ8OtgTjaYcQ1HjKs_SQ7eUHuJvdrDodvCk,65022
+ tinygrad/runtime/autogen/comgr.py,sha256=Z99Y6K8D_nuMpOs0qDWiA0MV-RxueV65o2OyPFdcsHE,38563
+ tinygrad/runtime/autogen/cuda.py,sha256=GgRl4AfU54JG0G1XJj2dq2FbrUZ8XG_AnFrPAZJpSSg,250380
+ tinygrad/runtime/autogen/hip.py,sha256=1yUHDCwL3KkD15if2Q1Ud3GbJiR7DxsNorKZTCINw54,245532
+ tinygrad/runtime/autogen/hsa.py,sha256=tGpnXUhhQkAIEr0yyCxRImzajmt-nN0KzJn4KnT_bH8,270073
+ tinygrad/runtime/autogen/kfd.py,sha256=dDmLFL6HL_QXRW2rZOCArY55PRuXuLN9563XCipV2jM,29935
+ tinygrad/runtime/autogen/nv_gpu.py,sha256=K9WwwdIitHrY2AXpYy8bbdD9aEwdbz9vL7748pz6Re0,1672024
+ tinygrad/runtime/autogen/opencl.py,sha256=aW-luGFF5PXFEZ6SgrGANhA9qpkI-fZyEsmDfpez2Ss,82654
+ tinygrad/runtime/driver/hip_comgr.py,sha256=rFQRsOYo4XquwcHFTe2mGzMfozdL9hitO3DRYBDFSuM,3376
+ tinygrad/runtime/driver/hsa.py,sha256=PoNy8gHBPoRUhUZligFp0z_Le9fyEXbJrnlgwInt_R0,7218
+ tinygrad/runtime/graph/clang.py,sha256=10Bs64J0r12g6upqCHVoK3LoTrdbBBHQD43efMhlBjo,1957
+ tinygrad/runtime/graph/cuda.py,sha256=LNx6RQLcQSKMlHfVK5r_efujN0lRPhKqi8yp249OAIs,5265
+ tinygrad/runtime/graph/hcq.py,sha256=mspwzroBTwNNHDob7oK-JCt48mhuIhX_G0qNYvFVuVM,8089
+ tinygrad/runtime/graph/hsa.py,sha256=UJgSg2irrKT87LBZ3DfaGmoK7rJk8OZhIHEHhtF8rUE,10035
+ tinygrad/runtime/graph/metal.py,sha256=bwB6uAsqjEbwv5ML5ziWduBtmTpseJboo6J9ssVa4v4,4579
+ tinygrad/shape/shapetracker.py,sha256=hWqh2uWsbBp3lKlRpY8Fj1oTWvEx1YwVsKl0QiA-QnU,6334
+ tinygrad/shape/symbolic.py,sha256=hn2khLoHAJSwyZ91i679oJZCLTaz0Sf2dUG-HRJMtVw,16688
+ tinygrad/shape/view.py,sha256=KMf_KzNwXmcX1NbFPq862-Jv_E6TgeO27lcPjrAweF4,17092
+ tinygrad-0.9.0.dist-info/LICENSE,sha256=ABRhUPEILzINYIukgazD-_rPipkUNUwslrb0RxnV6Xc,1058
+ tinygrad-0.9.0.dist-info/METADATA,sha256=oyGO3WSmMQ7NTAK3RGk0ZXCkr-L3XKltKYhYrKEuifk,10227
+ tinygrad-0.9.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ tinygrad-0.9.0.dist-info/top_level.txt,sha256=vDABMCWBFQnx2kn9Azueu88FP-1klQdePoHikQhHymc,9
+ tinygrad-0.9.0.dist-info/RECORD,,
{tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL RENAMED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: bdist_wheel (0.42.0)
+ Generator: bdist_wheel (0.43.0)
  Root-Is-Purelib: true
  Tag: py3-none-any
 
tinygrad/features/image.py DELETED
@@ -1,93 +0,0 @@
- from typing import Tuple
- from tinygrad.helpers import prod, IMAGE, getenv, DEBUG
- from tinygrad.dtype import dtypes
-
- # *** image Tensor function replacements ***
-
- def image_dot(self, w):
-   # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
-   n1, n2 = len(self.shape), len(w.shape)
-   assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
-   assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})" # noqa: E501
-   bs, groups, cin, cout = prod(self.shape[0:-2]), prod(w.shape[0:-2]), w.shape[-2], w.shape[-1]
-   out_shape_t = self.shape[0:-2] + (cout,-1) if len(self.shape) > 1 else (cout, )
-
-   # NOTE: with NHWC we can remove the transposes
-   # bs x groups*cin x H x W
-   cx = self.transpose(self.ndim-1, self.ndim-2).reshape((bs//groups, groups*cin, -1, 1))
-   # groups*cout x cin x H, W
-   cw = w.transpose(w.ndim-1, w.ndim-2).reshape((groups*cout, cin, 1, 1))
-   return image_conv2d(cx, cw, groups=groups).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)
-
- def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, padding=0):
-   base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef
-
-   (bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
-   x, w = self, weight.reshape(groups, (rcout := cout//groups), cin, H, W)
-
-   # hack for non multiples of 4 on cin
-   if cin % 4 != 0 and not (cin == 1 and groups%4 == 0):
-     x = x.reshape(bs, groups, cin, iy, ix) # do this always?
-     added_input_channels = 4 - (cin % 4)
-     w = w.pad(tuple((0, added_input_channels) if i == 2 else None for i in range(w.ndim)))
-     x = x.pad(tuple((0, added_input_channels) if i == 2 else None for i in range(x.ndim)))
-     cin = cin + added_input_channels
-     x = x.reshape(bs, groups*cin, iy, ix)
-
-   # hack for non multiples of 4 on rcout
-   added_output_channels = 0
-   if rcout % 4 != 0 and not (rcout == 1 and groups%4 == 0):
-     added_output_channels = 4 - (rcout % 4)
-     rcout += added_output_channels
-     cout = groups * rcout
-     w = w.pad(tuple((0, added_output_channels) if i == 1 else None for i in range(w.ndim)))
-
-   # packed (note: flipping bs and iy would make the auto-padding work)
-   x = x.permute(0,2,3,1)
-   cin_last = iy == 1 and ix == 1
-   if cin == 1: w = w.reshape(cout//4,4,H,W).permute(0,2,3,1)
-   elif cin_last: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,1,3)
-   else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,3,1)
-
-   # contiguous creates the image, and early realize static weights (TODO: test for the static weight)
-   if IMAGE >= 2: x,w = x.cast(base_image_type((bs*iy, ix*groups*cin//4, 4))), w.cast(base_image_type((cout//4, H*W*cin, 4)))
-   x, w = x.contiguous(), w.contiguous()
-
-   # expand out
-   rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1
-   cout_expand = [groups//4 if cin == 1 else groups, 4 if cin == 1 else 1, rcout//4 if rcout >= 4 else 1, 4 if rcout >= 4 else 1]
-   x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo)
-   if cin_last: w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo)
-   else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4)
-
-   # padding
-   padding_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]])
-   x = x.slice((None, (-padding_[2], x.shape[1]+padding_[3]), (-padding_[0], x.shape[2]+padding_[1]), None, None, None))
-
-   # prepare input
-   x = x.permute(0,3,4,5,1,2)._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
-   x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, (oy := x.shape[4]), (ox := x.shape[5]), *cout_expand[0:2], 1, 1, rcin_hi, rcin_lo, H, W)
-
-   # prepare weights
-   w = w.permute(0,4,2,5,1,3).reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W))
-
-   # the conv!
-   ret = (x*w).cast(base_image_type((bs*oy, ox*cout//4, 4)) if IMAGE >= 2 else dtypes.float32).sum((-4, -3, -2, -1))
-
-   # undo hack for non multiples of 4 on C.rcout
-   if added_output_channels != 0:
-     ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :, :, :-added_output_channels]
-     cout = groups * (rcout - added_output_channels)
-
-   # NCHW output
-   ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2)
-   return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
-
- # *** images have weird indexing requirements ***
-
- from tinygrad.shape.symbolic import Node
- def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node) -> Tuple[Tuple[Node, Node], Node]:
-   idx, idy = (idxy // 4) % base_shape[1], (idxy // (4 * base_shape[1]))
-   # TODO: bring back the valid removal logic (correct!)
-   if DEBUG>=5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy, valid)
-   return (idx, idy), valid
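The deleted `image_dot` above relies on the identity in its NOTE comment: an m×k @ k×n matmul can be written as a 1×1 conv2d. A standalone illustration of just that identity with plain `Tensor` ops (this sketches the comment, not the deleted implementation, which additionally packs the data into image dtypes):

```python
import numpy as np
from tinygrad import Tensor

m, k, n = 8, 16, 4
x, w = Tensor.rand(m, k), Tensor.rand(k, n)

# mxk @ kxn as a 1x1 conv: input (bs=1, cin=k, H=m, W=1), weight (cout=n, cin=k, 1, 1)
cx = x.transpose(0, 1).reshape(1, k, m, 1)
cw = w.transpose(0, 1).reshape(n, k, 1, 1)
out = cx.conv2d(cw).reshape(n, m).transpose(0, 1)   # back to (m, n)

assert np.allclose(out.numpy(), (x @ w).numpy(), atol=1e-4)
```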
tinygrad/features/multi.py DELETED
@@ -1,103 +0,0 @@
- from __future__ import annotations
- from typing import Optional, Union, Any, Tuple, List
- import functools
- from tinygrad.helpers import all_same, dedup
- from tinygrad.dtype import DType
- from tinygrad.ops import BinaryOps, LoadOps, UnaryOps, TernaryOps, ReduceOps
- from tinygrad.lazy import LazyBuffer, create_schedule
- from tinygrad.shape.shapetracker import ShapeTracker, sint
-
- def all_reduce(lbs):
-   # TODO: replace this with ring reduce
-   return [functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
-
- def to_sharded(lbs:List[LazyBuffer], axis:int) -> List[LazyBuffer]:
-   assert lbs[0].shape[axis] % len(lbs) == 0, f"{lbs[0].shape=} {axis=} {len(lbs)=}"
-   sz = lbs[0].shape[axis] // len(lbs)
-   return [lb.shrink(tuple((0,s) if a != axis else (sz*i,sz*(i+1)) for a,s in enumerate(lb.shape))) for i,lb in enumerate(lbs)]
-
- class MultiLazyBuffer:
-   def __init__(self, lbs:List[LazyBuffer], axis:Optional[int]):
-     assert all(isinstance(x, LazyBuffer) for x in lbs) and len(lbs) >= 2, "all lbs must be LazyBuffers, and we need at least two of them"
-     assert all_same([(x.shape, x.dtype, x.st) for x in lbs]), "all multilazybuffer needs same shape, dtype, and st"
-     self.lbs, self.axis, self.dtype, self.device = lbs, axis, lbs[0].dtype, tuple(x.device for x in lbs)
-     self.shape = tuple(s*len(self.lbs) if a == self.axis else s for a,s in enumerate(lbs[0].shape))
-
-   def __repr__(self):
-     return f"<MLB{chr(10)}{chr(10).join([f'{x.device} {x.st}' for x in self.lbs])}>"
-
-   @staticmethod
-   def from_sharded(lb:LazyBuffer, devices:Tuple[str, ...], axis:Optional[int]=None):
-     lbs = [lb.contiguous() if lb.base != lb else lb] * len(devices)
-     return MultiLazyBuffer([lb.copy_to_device(d).contiguous() for lb,d in zip(to_sharded(lbs, axis) if axis is not None else lbs, devices)], axis)
-
-   def copy_to_device(self, device:str) -> LazyBuffer:
-     if self.axis is None: return self.lbs[0].copy_to_device(device)
-     sz = self.lbs[0].shape[self.axis]
-     llbs = []
-     for i,lb in enumerate([lb.copy_to_device(device) for lb in self.lbs]):
-       pad_arg = tuple((0,0) if a != self.axis else (sz*i,(s*len(self.lbs))-sz*(i+1)) for a,s in enumerate(lb.shape))
-       llbs.append(lb.pad(pad_arg))
-     return functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), llbs)
-
-   # TODO: fix this
-   def is_unrealized_contiguous_const(self): return False
-
-   # passthroughs
-   def schedule(self, seen=None): return create_schedule(self.lbs, seen)
-   def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis)
-   def const(self, val:Union[float, int]) -> MultiLazyBuffer: return MultiLazyBuffer([x.const(val) for x in self.lbs], self.axis)
-   def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis)
-
-   # elementwise is simple
-   def e(self, op:Union[LoadOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:MultiLazyBuffer, arg:Optional[Any]=None) -> MultiLazyBuffer:
-     msrcs = (self,)+in_srcs
-     assert all(isinstance(x, MultiLazyBuffer) for x in msrcs), f"all buffers must be MultiLazyBuffer {msrcs}"
-     assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"
-
-     # NOTE: they all have to share an axis, we always choose [-1]
-     axis = axes[-1] if len(axes := dedup([x.axis for x in msrcs if x.axis is not None])) else None
-     srcs = []
-     for mlb in msrcs:
-       if mlb.axis == axis: srcs.append(mlb.lbs)
-       elif mlb.axis is None and axis is not None: srcs.append(to_sharded(mlb.lbs, axis))
-       else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis))
-     return MultiLazyBuffer([lsrcs[0].e(op, *lsrcs[1:], arg=arg) for lsrcs in zip(*srcs)], axis)
-
-   def _shape_to_single_shard(self, shape): return tuple(s//len(self.lbs) if a == self.axis else s for a,s in enumerate(shape))
-
-   def r(self, op:ReduceOps, new_shape:Tuple[sint, ...]) -> MultiLazyBuffer:
-     if self.axis is not None and new_shape[self.axis] == 1:
-       # all-reduce on sharded axes
-       return MultiLazyBuffer(all_reduce([x.r(op, new_shape) for x in self.lbs]), None)
-     # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
-     return MultiLazyBuffer([x.r(op, self._shape_to_single_shard(new_shape)) for x in self.lbs], self.axis)
-
-   # *** movement ops ***
-
-   def reshape(self, arg:Tuple[sint, ...]):
-     if self.axis is None: return MultiLazyBuffer([x.reshape(arg) for x in self.lbs], None)
-     # TODO: this can be wrong
-     st = ShapeTracker.from_shape(self.shape)
-     rs = st.real_strides()[self.axis]
-     new_axis = st.reshape(arg).real_strides().index(rs)
-     narg = tuple(s//len(self.lbs) if a == new_axis else s for a,s in enumerate(arg))
-     return MultiLazyBuffer([x.reshape(narg) for x in self.lbs], new_axis)
-
-   def pad(self, arg:Tuple[Tuple[sint, sint], ...]):
-     assert self.axis is None or arg[self.axis] == (0,0), "padding not supported on sharded axis"
-     return MultiLazyBuffer([x.pad(arg) for x in self.lbs], self.axis)
-   def expand(self, arg:Tuple[sint, ...]):
-     # NOTE: this assert isn't needed, sharded axis can have dim 1
-     assert self.axis is None or arg[self.axis] == self.lbs[0].shape[self.axis] * len(self.lbs), "expand not supported on sharded axis"
-     return MultiLazyBuffer([x.expand(self._shape_to_single_shard(arg)) for x in self.lbs], self.axis)
-   def permute(self, arg:Tuple[int, ...]):
-     # all permutes supported!
-     return MultiLazyBuffer([x.permute(arg) for x in self.lbs], arg.index(self.axis) if self.axis is not None else None)
-   def shrink(self, arg:Tuple[Tuple[sint, sint], ...]):
-     assert self.axis is None or arg[self.axis] == (0, self.lbs[0].shape[self.axis] * len(self.lbs)), "shrinking not supported on sharded axis"
-     narg = tuple((s1//len(self.lbs), s2//len(self.lbs)) if a == self.axis else (s1,s2) for a,(s1,s2) in enumerate(arg))
-     return MultiLazyBuffer([x.shrink(narg) for x in self.lbs], self.axis)
-   def stride(self, arg:Tuple[int, ...]):
-     assert self.axis is None or arg[self.axis] == 1, "flipping not supported on sharded axis"
-     return MultiLazyBuffer([x.stride(arg) for x in self.lbs], self.axis)
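`MultiLazyBuffer` is not gone in 0.9.0; per the file list it moves to `tinygrad/multi.py` and keeps backing multi-device tensors created through `Tensor.shard`. A hedged sketch of that user-facing path (the device pair is a placeholder; use whatever backends are available locally):

```python
from tinygrad import Tensor, Device

# Two instances of the default backend, e.g. ("CLANG:0", "CLANG:1") or ("CUDA:0", "CUDA:1").
devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))

t = Tensor.rand(256, 128)
sharded = t.shard(devices, axis=0)        # split axis 0 across the two devices
print(sharded.shape, sharded.device)      # (256, 128) and the device tuple

# Reducing over the sharded axis goes through the all_reduce path shown above.
print(sharded.sum(axis=0).numpy().shape)  # (128,)
```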
tinygrad/features/search.py DELETED
@@ -1,160 +0,0 @@
- from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
- import itertools, random, math, time, multiprocessing, traceback, signal
- from tinygrad.device import Device, Compiled, Buffer, CompiledASTRunner
- from tinygrad.ops import MemBuffer, LazyOp
- from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
- from tinygrad.dtype import ImageDType
- from tinygrad.codegen.linearizer import Linearizer
- from collections import defaultdict
- from tinygrad.tensor import Tensor
- from tinygrad.shape.symbolic import sym_infer
-
- from tinygrad.codegen.kernel import Opt, OptOps
- actions = [Opt(op=OptOps.UPCAST, axis=axis, amt=amt) for amt in [0,2,3,4,7] for axis in range(6)]
- actions += [Opt(op=OptOps.UNROLL, axis=axis, amt=amt) for amt in [0,4] for axis in range(4)]
- actions += [Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,13,16,29] for axis in range(5)]
- actions += [Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,29,32,256] for axis in range(3)]
- actions += [Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32] for axis in range(7)]
- actions += [Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.UPCASTMID, axis=1, amt=4),
-             Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.GROUP, axis=1, amt=8),]
- if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]
-
- def _get_test_global_size(global_size, max_global_size, var_vals):
-   test_global_size, factor = [sym_infer(sz, var_vals) for sz in global_size], 1
-   while prod(test_global_size) > max_global_size:
-     for j in range(len(global_size)-1,-1,-1):
-       if test_global_size[j] > 16:
-         test_global_size[j] //= 2
-         factor *= 2
-         break
-   return test_global_size, factor
-
- def _time_program(ast:LazyOp, rdev:Compiled, lib:bytes, global_size, local_size, var_vals, rawbufs, early_stop=None, max_global_size=65536, clear_l2=False, cnt=3, name="test"): # noqa: E501
-   factor = 1
-   if global_size is not None and max_global_size is not None:
-     global_size, factor = _get_test_global_size(global_size, max_global_size, var_vals)
-   car = CompiledASTRunner(ast, name, "", lib, global_size, local_size).build(rdev.runtime)
-   tms = []
-   for _ in range(cnt):
-     if clear_l2:
-       with Context(DEBUG=0): Tensor.rand(1024,1024).realize()
-     tms.append(car(rawbufs, var_vals, wait=True, do_update_stats=False)*factor)
-     if early_stop is not None and early_stop < tms[-1]: break
-   return tms
-
- def _compile_linearizer(rdev:Compiled, lin:Linearizer, name:Optional[str]=None) -> Tuple[bytes, Optional[List[int]], Optional[List[int]]]:
-   lin.linearize()
-   src = rdev.renderer(name if name is not None else to_function_name(lin.name), lin.uops) # NOTE: these all have the same name for deduping
-   return rdev.compiler(src), lin.global_size, lin.local_size
-
- def _try_compile_linearized_w_idx(x):
-   try: return (x[0], _compile_linearizer(cast(Compiled, Device[Device.DEFAULT]), x[1], "test"))
-   except Exception:
-     if DEBUG >= 4: traceback.print_exc()
-     return (x[0], None)
-
- # workers should ignore ctrl c
- def _init_worker(): signal.signal(signal.SIGINT, signal.SIG_IGN)
-
- # *** external API ***
-
- # get (scrap) buffers for timing the linearizer
- def bufs_from_lin(lin:Linearizer) -> List[Buffer]:
-   bufsts:DefaultDict[int, List[MemBuffer]] = defaultdict(list)
-   for x in lin.membufs: bufsts[x.idx].append(x)
-   rawbufs:List[Optional[Buffer]] = [None]*len(bufsts)
-   for k,lx in bufsts.items():
-     buf_size = prod(lx[0].dtype.shape) if isinstance(lx[0].dtype, ImageDType) else max(y.st.real_size() for y in lx)
-     rawbufs[k] = Buffer(Device.DEFAULT, buf_size, lx[0].dtype)
-   assert all(r is not None for r in rawbufs)
-   return cast(List[Buffer], rawbufs)
-
- # get dictionary of all possible actions
- def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Linearizer]:
-   acted_lins = {0:lin} if include_0 else {}
-   for i,a in enumerate(actions):
-     if a.axis is not None and a.axis >= lin.shape_len: continue
-     if a.axis is not None and lin.full_shape[a.axis] == a.amt and Opt(a.op, a.axis, 0) in actions: continue
-     lin2 = lin.copy()
-     try:
-       lin2.apply_opt(a)
-       up, lcl = 1, 1
-       for s,c in zip(lin2.full_shape, lin2.colors()):
-         if c in {"magenta", "yellow"}: up *= s
-         if c in {"cyan", "green", "white"}: lcl *= s
-       if up > 256 or lcl > 256: continue
-       acted_lins[i+1] = lin2
-     except Exception:
-       pass
-   return acted_lins
-
- def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linearizer:
-   key = {"ast": str(lin.ast), "amt": amt, "allow_test_size": allow_test_size, "device": Device.DEFAULT}
-   if (val:=diskcache_get("beam_search", key)) is not None and not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1:
-     ret = lin.copy()
-     for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
-     return ret
-
-   beam: List[Tuple[Linearizer, float]] = []
-   seen_libs = set()
-
-   default_parallel = 1 if Device.DEFAULT in {"CUDA", "HIP"} else 0
-   pool = multiprocessing.Pool(multiprocessing.cpu_count(), _init_worker) if getenv("PARALLEL", default_parallel) else None
-
-   try:
-     var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()}
-     exiting, st = False, time.perf_counter()
-     dev = Device[Device.DEFAULT]
-     assert isinstance(dev, Compiled)
-     while not exiting:
-       acted_lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam]) if len(beam) else [lin]
-       timed_lins: List[Tuple[Linearizer, float]] = []
-       for i,proc in (pool.imap_unordered(_try_compile_linearized_w_idx, enumerate(acted_lins)) if pool is not None else map(_try_compile_linearized_w_idx, enumerate(acted_lins))): # noqa: E501
-         if proc is None: continue
-         lib, global_size, local_size = proc
-         if lib in seen_libs: continue
-         seen_libs.add(lib)
-         tms = _time_program(lin.ast, dev, lib, global_size, local_size, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0)
-         timed_lins.append((acted_lins[i], min(tms)))
-         if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501
-
-       # done
-       opts = sorted(timed_lins, key=lambda x: x[1])
-       exiting = len(opts) == 0 or (len(beam) > 0 and beam[0][1] <= opts[0][1])
-       if not exiting: beam = opts[:amt]
-       assert len(beam) > 0, "no BEAM items succeeded?!?"
-       if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(f"{beam[0][1]*1e6:12.2f} us", "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape()) # noqa: E501
-     if pool is not None: pool.close() # the pool is closed
-   except KeyboardInterrupt as e:
-     if pool is not None: pool.terminate()
-     raise e
-
-   if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
-   if DEBUG >= 3: print(beam[0][0].applied_opts)
-   return beam[0][0]
-
- def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[Buffer]) -> List[int]:
-   test_rawbuffers = [Buffer(rawbufs[0].device, rawbufs[0].size, rawbufs[0].dtype), *rawbufs[1:]] if rawbufs[0] in rawbufs[1:] else rawbufs
-   MAX_WORKGROUP = 1024
-   local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in global_size]
-   local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice
-   def try_exec(local_size):
-     try: return clprg(*[x._buf for x in test_rawbuffers], global_size=[g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)], local_size=local_size, wait=True) # noqa: E501
-     except Exception: return float('inf')
-   ret = min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])
-   assert not math.isinf(ret[0]), "all optimize_local_size exec failed"
-   return ret[1]
-
- def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501
-   key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size, "clear_l2": clear_l2, "device": Device.DEFAULT} # noqa: E501
-   if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
-
-   dev = Device[Device.DEFAULT]
-   assert isinstance(dev, Compiled)
-
-   var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()}
-   lib, global_size, local_size = _compile_linearizer(dev, lin)
-   tms = _time_program(lin.ast, dev, lib, global_size, local_size, var_vals, rawbufs, max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name)) # noqa: E501
-
-   if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms)
-   return min(tms)
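This search code survives as `tinygrad/engine/search.py` (see the file list), and the usual way to reach it from user code is the `BEAM` context/environment variable rather than calling `beam_search` directly. A hedged sketch, assuming the `BEAM` ContextVar in 0.9.0 behaves as it does in the code above (kernels realized while it is set are tuned by beam search and the winning opts are cached on disk):

```python
from tinygrad import Tensor
from tinygrad.helpers import Context

a, b = Tensor.rand(512, 512), Tensor.rand(512, 512)

# Roughly equivalent to running the script with BEAM=2 in the environment: kernels
# compiled inside this block are chosen by beam_search() (now tinygrad/engine/search.py)
# instead of the default heuristics, and results land in the "beam_search" disk cache.
with Context(BEAM=2):
  (a @ b).realize()
```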
tinygrad/graph.py DELETED
@@ -1,106 +0,0 @@
- import os, atexit
- from collections import defaultdict
- from typing import List, Any, DefaultDict
- from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, BufferOps, TernaryOps, Op, LazyOp, GlobalCounters
- from tinygrad.device import Device
- from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, getenv
- from tinygrad.codegen.linearizer import UOps, UOp
- from tinygrad.shape.symbolic import NumNode
-
- # **** debugging and graphing ****
-
- if DEBUG >= 2:
-   def print_globalcounters():
-     if GlobalCounters.time_sum_s == 0: return
-     print(f"avg: {GlobalCounters.global_ops*1e-9/GlobalCounters.time_sum_s:8.2f} GFLOPS {GlobalCounters.global_mem*1e-9/GlobalCounters.time_sum_s:8.2f} GB/s", # noqa: E501
-           f"{' '*10}total: {GlobalCounters.kernel_count:5d} kernels {GlobalCounters.global_ops*1e-9:8.2f} GOPS {GlobalCounters.global_mem*1e-9:8.2f} GB {GlobalCounters.time_sum_s*1e3:8.2f} ms") # noqa: E501
-   atexit.register(print_globalcounters)
-
- G:Any = None
- def init_graph():
-   global G
-   if G is not None: return
-   import networkx as nx
-   G = nx.DiGraph()
-   def save_graph_exit():
-     print("saving", G, f"to {GRAPHPATH}.svg")
-     nx.drawing.nx_pydot.write_dot(G, f'{GRAPHPATH}.dot')
-     # -Gnslimit=100 can make it finish, but you won't like results
-     os.system(f'dot -Tsvg {GRAPHPATH}.dot -o {GRAPHPATH}.svg')
-   atexit.register(save_graph_exit)
-
- counts: DefaultDict[type, int] = defaultdict(int)
- def nm(x):
-   if not hasattr(x, 'node_id'):
-     setattr(x, 'node_id', counts[type(x)])
-     counts[type(x)] += 1
-   return x.node_id
-
- def get_sop(op: List[Op]):
-   op = [x for x in op if x not in BufferOps]
-   if len(op) <= 2: return '.'.join([str(y).split(".")[1] for y in op][::-1])
-   if len(op) <= 6: return '.'.join([str(y).split(".")[1][0:3] for y in op][::-1])
-   return str(len(op))
-
- def str_dtype(dtyp):
-   ret = str(dtyp)[7:]
-   return "" if ret == 'float' else f"\n{ret}"
-
- def realized_lazybuffer(lb, num):
-   if GRAPH:
-     init_graph()
-     G.nodes[nm(lb)]['style'] = '"filled,bold"'
-     G.nodes[nm(lb)]['fillcolor'] = G.nodes[nm(lb)]['fillcolor'][:-2]
-     G.nodes[nm(lb)]['label'] = '"' + G.nodes[nm(lb)]["label"].replace('"', '') + f'\nK:{num} b:{nm(lb.realized)}"'
-
- def log_lazybuffer(lb, scheduled=False):
-   top_colors = {LoadOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#FFA0A0", BinaryOps: "#c0c0c0",
-                 MovementOps: "#80ff80", TernaryOps: "#c0c0c0", BufferOps: '#a0a0ff'}
-   if GRAPH:
-     init_graph()
-     if lb.base != lb:
-       offset = lb.st.expr_node(NumNode(0))[0]
-       label = f"{lb.st.shape}\n{lb.st.real_strides()}" + (f"\n{offset}" if offset != 0 else "")
-       G.add_node(nm(lb), style='"filled,dashed"', fillcolor="#80ff8080", color="black", label=label)
-       G.add_edge(nm(lb.base), nm(lb), color='#00000060')
-       lb = lb.base
-     if lb.realized is None:
-       for x in lb.srcs:
-         if nm(x) not in G.nodes: log_lazybuffer(x)
-         G.add_edge(nm(x), nm(lb), color='#a0a0a0')
-       label = '"' + \
-         (str(set(x.shape for x in lb.srcs))+"\n"+str(lb.shape) if lb.op in ReduceOps else str(lb.shape)) + \
-         str_dtype(lb.dtype)+f"\n{lb.op}"+(f"\n{lb.arg}" if lb.op in {LoadOps.CONST, UnaryOps.CAST} else "") + \
-         (f"\n{lb.device}" if lb.device != Device.DEFAULT else "") + '"'
-       G.add_node(nm(lb), style='"filled,dashed"', fillcolor=[v for k,v in top_colors.items() if lb.op in k][0] + "80", color="black", label=label)
-       if scheduled: G.nodes[nm(lb)]['shape'] = 'box'
-     else:
-       if nm(lb) not in G.nodes:
-         # realized but unseen?
-         G.add_node(nm(lb), label=f'"{str(lb.base.realized)[5:-1].replace(" ", chr(10))}\nb:{nm(lb.realized)}"', style='filled', fillcolor="#f0c08080")
-
- def _tree(lazydata, cycles, cnt, prefix=""):
-   cnt[0] += 1
-   if len(lazydata.src) == 0: return [f"━━ {prefix}{lazydata.op.name} {lazydata.arg if lazydata.arg else ''}"]
-   if (lid := id(lazydata)) in cycles and cycles[lid][1] > (tcnt := getenv("TREE_CYCLE_CNT", 5)) and tcnt >= 0:
-     return [f"━⬆︎ goto {cycles[id(lazydata)][0]}: {lazydata.op.name}"]
-   cycles[lid] = (cnt[0], 1 if lid not in cycles else cycles[lid][1]+1)
-   lines = [f"━┳ {prefix}{lazydata.op.name} {lazydata.arg if lazydata.arg else ''}"]
-   childs = [_tree(c, cycles, cnt) for c in lazydata.src[:]]
-   for c in childs[:-1]: lines += [f" ┣{c[0]}"] + [f" ┃{l}" for l in c[1:]]
-   return lines + [" ┗"+childs[-1][0]] + [" "+l for l in childs[-1][1:]]
-
- def print_tree(lazydata:LazyOp): print("\n".join([f"{str(i).rjust(3)} {s}" for i,s in enumerate(_tree(lazydata, {}, [-1]))]))
-
- def graph_uops(uops:List[UOp]):
-   import networkx as nx
-   colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.SPECIAL: "#c0c0ff", UOps.CONST: "#e0e0e0",
-             UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0",
-             UOps.LOOP: "#c8a0e0", UOps.PHI: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0"}
-   G = nx.DiGraph()
-   for u in uops:
-     if u.uop == UOps.END: continue
-     G.add_node(uops.index(u), label=f"{str(u.uop)[5:]}{(' '+str(u.arg)) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.uop, "#ffffff")) # noqa: E501
-     for v in u.vin: G.add_edge(uops.index(v), uops.index(u))
-   nx.drawing.nx_pydot.write_dot(G, f'{GRAPHPATH}.uops.dot')
-   os.system(f'dot -Grankdir=LR -Tsvg {GRAPHPATH}.uops.dot -o {GRAPHPATH}.uops.svg')
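The graphing helpers end up in `tinygrad/engine/graph.py` (file 8 in the list above) rather than disappearing, and they are still driven by environment variables. A hedged usage sketch, assuming `GRAPH`/`GRAPHPATH` keep the behavior shown in the deleted code (an atexit hook writes the lazy graph to `$GRAPHPATH.svg`, default `/tmp/net.svg`, and needs graphviz's `dot` on the PATH):

```python
# Run as: GRAPH=1 python viz_example.py   (viz_example.py is a hypothetical script name)
from tinygrad import Tensor

x = Tensor.rand(4, 4)
y = (x @ x).relu().sum()
y.realize()   # on exit, the registered hook dumps the graph of realized LazyBuffers
```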