tinygrad 0.10.0__py3-none-any.whl → 0.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/kernel.py +114 -172
- tinygrad/codegen/linearize.py +211 -81
- tinygrad/codegen/lowerer.py +30 -35
- tinygrad/codegen/{uopgraph.py → rewriter.py} +69 -59
- tinygrad/codegen/transcendental.py +12 -13
- tinygrad/device.py +170 -47
- tinygrad/dtype.py +28 -26
- tinygrad/engine/jit.py +80 -63
- tinygrad/engine/memory.py +4 -5
- tinygrad/engine/multi.py +162 -0
- tinygrad/engine/realize.py +58 -107
- tinygrad/engine/schedule.py +381 -314
- tinygrad/engine/search.py +40 -44
- tinygrad/gradient.py +70 -0
- tinygrad/helpers.py +77 -58
- tinygrad/nn/__init__.py +30 -32
- tinygrad/nn/datasets.py +1 -2
- tinygrad/nn/optim.py +22 -26
- tinygrad/nn/state.py +89 -64
- tinygrad/ops.py +562 -446
- tinygrad/renderer/__init__.py +79 -36
- tinygrad/renderer/cstyle.py +70 -84
- tinygrad/renderer/llvmir.py +32 -20
- tinygrad/renderer/ptx.py +79 -99
- tinygrad/renderer/wgsl.py +87 -0
- tinygrad/runtime/autogen/amd_gpu.py +39507 -12
- tinygrad/runtime/autogen/comgr.py +2 -0
- tinygrad/runtime/autogen/kfd.py +4 -3
- tinygrad/runtime/autogen/kgsl.py +1 -1
- tinygrad/runtime/autogen/libpciaccess.py +2023 -0
- tinygrad/runtime/autogen/llvm.py +11379 -0
- tinygrad/runtime/autogen/vfio.py +891 -0
- tinygrad/runtime/graph/cuda.py +8 -9
- tinygrad/runtime/graph/hcq.py +84 -79
- tinygrad/runtime/graph/metal.py +19 -21
- tinygrad/runtime/ops_amd.py +488 -327
- tinygrad/runtime/ops_clang.py +15 -28
- tinygrad/runtime/ops_cloud.py +34 -34
- tinygrad/runtime/ops_cuda.py +30 -27
- tinygrad/runtime/ops_disk.py +62 -63
- tinygrad/runtime/ops_dsp.py +129 -38
- tinygrad/runtime/ops_gpu.py +30 -30
- tinygrad/runtime/ops_hip.py +29 -31
- tinygrad/runtime/ops_llvm.py +45 -40
- tinygrad/runtime/ops_metal.py +93 -73
- tinygrad/runtime/ops_npy.py +2 -2
- tinygrad/runtime/ops_nv.py +232 -270
- tinygrad/runtime/ops_python.py +51 -46
- tinygrad/runtime/ops_qcom.py +129 -157
- tinygrad/runtime/ops_webgpu.py +63 -0
- tinygrad/runtime/support/allocator.py +94 -0
- tinygrad/runtime/support/am/__init__.py +0 -0
- tinygrad/runtime/support/am/amdev.py +384 -0
- tinygrad/runtime/support/am/ip.py +463 -0
- tinygrad/runtime/support/compiler_cuda.py +4 -2
- tinygrad/runtime/support/elf.py +26 -4
- tinygrad/runtime/support/hcq.py +254 -324
- tinygrad/runtime/support/llvm.py +32 -0
- tinygrad/shape/shapetracker.py +84 -53
- tinygrad/shape/view.py +103 -138
- tinygrad/spec.py +154 -0
- tinygrad/tensor.py +744 -496
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/METADATA +32 -21
- tinygrad-0.10.1.dist-info/RECORD +86 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/WHEEL +1 -1
- tinygrad/engine/lazy.py +0 -228
- tinygrad/function.py +0 -212
- tinygrad/multi.py +0 -177
- tinygrad/runtime/graph/clang.py +0 -39
- tinygrad-0.10.0.dist-info/RECORD +0 -77
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.2
|
2
2
|
Name: tinygrad
|
3
|
-
Version: 0.10.
|
3
|
+
Version: 0.10.1
|
4
4
|
Summary: You like pytorch? You like micrograd? You love tinygrad! <3
|
5
5
|
Author: George Hotz
|
6
6
|
License: MIT
|
@@ -11,26 +11,19 @@ Description-Content-Type: text/markdown
|
|
11
11
|
License-File: LICENSE
|
12
12
|
Provides-Extra: arm
|
13
13
|
Requires-Dist: unicorn; extra == "arm"
|
14
|
-
Provides-Extra:
|
15
|
-
Requires-Dist:
|
16
|
-
Requires-Dist: mkdocs-material; extra == "docs"
|
17
|
-
Requires-Dist: mkdocstrings[python]; extra == "docs"
|
18
|
-
Requires-Dist: markdown-callouts; extra == "docs"
|
19
|
-
Requires-Dist: markdown-exec[ansi]; extra == "docs"
|
20
|
-
Requires-Dist: black; extra == "docs"
|
21
|
-
Requires-Dist: numpy; extra == "docs"
|
14
|
+
Provides-Extra: triton
|
15
|
+
Requires-Dist: triton-nightly>=2.1.0.dev20231014192330; extra == "triton"
|
22
16
|
Provides-Extra: linting
|
23
17
|
Requires-Dist: pylint; extra == "linting"
|
24
|
-
Requires-Dist: mypy==1.
|
18
|
+
Requires-Dist: mypy==1.13.0; extra == "linting"
|
25
19
|
Requires-Dist: typing-extensions; extra == "linting"
|
26
20
|
Requires-Dist: pre-commit; extra == "linting"
|
27
21
|
Requires-Dist: ruff; extra == "linting"
|
28
22
|
Requires-Dist: types-tqdm; extra == "linting"
|
29
|
-
Provides-Extra: llvm
|
30
|
-
Requires-Dist: llvmlite; extra == "llvm"
|
31
23
|
Provides-Extra: testing
|
32
24
|
Requires-Dist: numpy; extra == "testing"
|
33
25
|
Requires-Dist: torch; extra == "testing"
|
26
|
+
Requires-Dist: jax; extra == "testing"
|
34
27
|
Requires-Dist: pillow; extra == "testing"
|
35
28
|
Requires-Dist: pytest; extra == "testing"
|
36
29
|
Requires-Dist: pytest-xdist; extra == "testing"
|
@@ -50,11 +43,28 @@ Requires-Dist: hypothesis; extra == "testing"
|
|
50
43
|
Requires-Dist: nibabel; extra == "testing"
|
51
44
|
Requires-Dist: bottle; extra == "testing"
|
52
45
|
Requires-Dist: ggml-python; extra == "testing"
|
53
|
-
|
46
|
+
Requires-Dist: capstone; extra == "testing"
|
47
|
+
Provides-Extra: webgpu
|
48
|
+
Requires-Dist: wgpu; extra == "webgpu"
|
49
|
+
Provides-Extra: docs
|
50
|
+
Requires-Dist: mkdocs; extra == "docs"
|
51
|
+
Requires-Dist: mkdocs-material; extra == "docs"
|
52
|
+
Requires-Dist: mkdocstrings[python]; extra == "docs"
|
53
|
+
Requires-Dist: markdown-callouts; extra == "docs"
|
54
|
+
Requires-Dist: markdown-exec[ansi]; extra == "docs"
|
55
|
+
Requires-Dist: black; extra == "docs"
|
56
|
+
Requires-Dist: numpy; extra == "docs"
|
57
|
+
Provides-Extra: testing-tf
|
54
58
|
Requires-Dist: tensorflow==2.15.1; extra == "testing-tf"
|
55
|
-
Requires-Dist:
|
56
|
-
|
57
|
-
|
59
|
+
Requires-Dist: tensorflow_addons; extra == "testing-tf"
|
60
|
+
Dynamic: author
|
61
|
+
Dynamic: classifier
|
62
|
+
Dynamic: description
|
63
|
+
Dynamic: description-content-type
|
64
|
+
Dynamic: license
|
65
|
+
Dynamic: provides-extra
|
66
|
+
Dynamic: requires-python
|
67
|
+
Dynamic: summary
|
58
68
|
|
59
69
|
<div align="center">
|
60
70
|
|
@@ -146,6 +156,7 @@ tinygrad already supports numerous accelerators, including:
|
|
146
156
|
- [x] [AMD](tinygrad/runtime/ops_amd.py)
|
147
157
|
- [x] [NV](tinygrad/runtime/ops_nv.py)
|
148
158
|
- [x] [QCOM](tinygrad/runtime/ops_qcom.py)
|
159
|
+
- [x] [WEBGPU](tinygrad/runtime/ops_webgpu.py)
|
149
160
|
|
150
161
|
And it is easy to add more! Your accelerator of choice only needs to support a total of ~25 low level ops.
|
151
162
|
|
@@ -183,8 +194,8 @@ y = Tensor([[2.0,0,-2.0]], requires_grad=True)
|
|
183
194
|
z = y.matmul(x).sum()
|
184
195
|
z.backward()
|
185
196
|
|
186
|
-
print(x.grad.
|
187
|
-
print(y.grad.
|
197
|
+
print(x.grad.tolist()) # dz/dx
|
198
|
+
print(y.grad.tolist()) # dz/dy
|
188
199
|
```
|
189
200
|
|
190
201
|
The same thing but in PyTorch:
|
@@ -196,8 +207,8 @@ y = torch.tensor([[2.0,0,-2.0]], requires_grad=True)
|
|
196
207
|
z = y.matmul(x).sum()
|
197
208
|
z.backward()
|
198
209
|
|
199
|
-
print(x.grad.
|
200
|
-
print(y.grad.
|
210
|
+
print(x.grad.tolist()) # dz/dx
|
211
|
+
print(y.grad.tolist()) # dz/dy
|
201
212
|
```
|
202
213
|
|
203
214
|
## Contributing
|
@@ -0,0 +1,86 @@
|
|
1
|
+
tinygrad/__init__.py,sha256=2Jhg7NSWlegCi4OAfGW0iHRVHeqMx09f7446rwAmc60,587
|
2
|
+
tinygrad/device.py,sha256=mUrxoZJfqBJepDeqbmS2Y8UbOB_UPERqP9zBN7sCBk8,18968
|
3
|
+
tinygrad/dtype.py,sha256=010zGuqXwUoyrhe23nJQ5oHmJt_6CRL7Hcl0kWjvKbo,9843
|
4
|
+
tinygrad/gradient.py,sha256=hVyRMnwzjtWwKfF0NMGQovE2I7_2GdKFyR9pWJn4eE4,4280
|
5
|
+
tinygrad/helpers.py,sha256=DIofGx-mg-umNKjhNpINcrxZExUwo6JIKHgmFuKNLUM,19203
|
6
|
+
tinygrad/ops.py,sha256=T1gmR1ywWPnGEcw-MPAW5SGEvNS_789PtnFrR8rz5H4,70614
|
7
|
+
tinygrad/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
8
|
+
tinygrad/spec.py,sha256=Wi8PdgeIHzv6WCTwKT9zGP5_53Bt_QtmvGTjIzTuwi0,8953
|
9
|
+
tinygrad/tensor.py,sha256=-DVucOwq8HmM7IHqW7ZNABy7-vYZI-YfF0Vzdo0rs70,181559
|
10
|
+
tinygrad/codegen/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
|
+
tinygrad/codegen/kernel.py,sha256=4TcuMvvGODd0zFBuhK9KzvQhwzknRgeCL_n9WrOpljQ,42432
|
12
|
+
tinygrad/codegen/linearize.py,sha256=X9aOMhvimDTvMq2Vnj3zfGNIRhJnAG5mLcq6EeUfvvU,10382
|
13
|
+
tinygrad/codegen/lowerer.py,sha256=DqFg53oALqWG7RygOjXEioyGk52BnGbCzvQSqiOIKGw,7527
|
14
|
+
tinygrad/codegen/rewriter.py,sha256=T3tOn-T1IFqiP90GCqfmkr0V5i0knr2tjQfKPVr6zzM,29948
|
15
|
+
tinygrad/codegen/transcendental.py,sha256=0qRaEtIoJKDfjPqvQWexShZW3F__wtzfjhU__BqiMD8,13112
|
16
|
+
tinygrad/engine/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
17
|
+
tinygrad/engine/jit.py,sha256=6D8pCpRPZNANwFEHAsAMUgUaEwaCf4mB_1oem6oZZ7o,16630
|
18
|
+
tinygrad/engine/memory.py,sha256=UyiNYIoUjtcJ1SX6ApoCnbrSKZpbBbhfwr13TIaiqEM,3231
|
19
|
+
tinygrad/engine/multi.py,sha256=FnyCvwFTdBXiydXB0dD5Y63KnerCurZS_o8WOU5fHFM,10371
|
20
|
+
tinygrad/engine/realize.py,sha256=vOQ33sTn_P2oT0GRFU5tKYLfiRmxrIIdVs27v0_a40g,9690
|
21
|
+
tinygrad/engine/schedule.py,sha256=9JywWlmVo912JmbIF9r5ajr5-Sesqw1EhE-CYnk5LHI,27407
|
22
|
+
tinygrad/engine/search.py,sha256=VOsUOPaI7zWtDgL2SJDiKSPaTs3MhjgZOJBZ5_ISeiY,12093
|
23
|
+
tinygrad/nn/__init__.py,sha256=BAxMz-g7v-v1A32KaBzmGiaEnvQ_dU9d3KoPdYYwLDQ,15156
|
24
|
+
tinygrad/nn/datasets.py,sha256=wcT0Qrlpw_RzM7uBy8uphzKAjGT-ZE48fiP-0g3WvI0,1042
|
25
|
+
tinygrad/nn/optim.py,sha256=qfdYKi_ssX5O_DU6h8GJ0WCzBzAZLyyS3p_946PJNsQ,6816
|
26
|
+
tinygrad/nn/state.py,sha256=zXFMwAw7sf35C9x3iAxYiPsmhk7_S6qPcX3wXxqp6Bw,16030
|
27
|
+
tinygrad/renderer/__init__.py,sha256=KvZ3y7MnqKDHOKoTcxghT95FhTBJuzs3IFEBM077jw8,6989
|
28
|
+
tinygrad/renderer/cstyle.py,sha256=TrmxtioR7Lgw_8CK-Ay9EA0HjS1TXRf-RaMsCsuHszw,29852
|
29
|
+
tinygrad/renderer/llvmir.py,sha256=iHFjE9-GtH2sGD_feCGe7aGECPHJ5qGOyuuyCDxSEXU,8529
|
30
|
+
tinygrad/renderer/ptx.py,sha256=is4PMkdbDDgKsTy67DtkugatdhbtVMijKE9u1f6-0ag,14899
|
31
|
+
tinygrad/renderer/wgsl.py,sha256=3AGZp2jvymxLZSJHjzD7zlwxvds36-lSu2ZkntrW_ww,6697
|
32
|
+
tinygrad/runtime/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
33
|
+
tinygrad/runtime/ops_amd.py,sha256=zRictmcgUPw6d8iykz6QWqCmiPirOh-hTACT-GYxso8,37282
|
34
|
+
tinygrad/runtime/ops_clang.py,sha256=InJPVVhY10bxRXOaDQRUBlWafSZRdo72Y0VVW9O6dBw,1310
|
35
|
+
tinygrad/runtime/ops_cloud.py,sha256=1umLappzYKg0ECC9C_EN1lzXquAQVqY7wj-mS0vTcSk,10362
|
36
|
+
tinygrad/runtime/ops_cuda.py,sha256=26ko2bTLT_jDmU2i8_xiL4BomD9krzvlz6wY5FE-e5c,7138
|
37
|
+
tinygrad/runtime/ops_disk.py,sha256=3eTpJSVQQH32KBa7tRLBm2pPSkCaIC1adPc748UkUeg,6619
|
38
|
+
tinygrad/runtime/ops_dsp.py,sha256=JCSehSDLzaa8Nhd6Xtas0d5vOCgwQGY4U3awHOLThEI,16900
|
39
|
+
tinygrad/runtime/ops_gpu.py,sha256=VrY9iM5i44OQwTJE8sgUavCtr1TOpE96I5-JmqUFz-E,8972
|
40
|
+
tinygrad/runtime/ops_hip.py,sha256=MbR4depxgHcaGpOJGvUCiLq5tdbpaDiqs-Xj41rY2xQ,3730
|
41
|
+
tinygrad/runtime/ops_llvm.py,sha256=yjW5wVMJdmayeeTQwnQ60mtOWiiYEqvtb9Rt5cDj2rU,3239
|
42
|
+
tinygrad/runtime/ops_metal.py,sha256=rz3fiaiB1by_RuP7VFVBgGA1mye43_-6PIfMfeQ3YmU,13229
|
43
|
+
tinygrad/runtime/ops_npy.py,sha256=8VNf1S5M_MRk9d3GxSsTPbfEz7I_aOwl7QMZ1mUG3As,370
|
44
|
+
tinygrad/runtime/ops_nv.py,sha256=8NvxtEc89dbJ5xgVhGs9_zeGMztp9HRZHLwPqylDsdQ,34541
|
45
|
+
tinygrad/runtime/ops_python.py,sha256=5CZek-7fMhq6w07lGDM0pXltZzzMMRtuU2x5PeWMyWY,11681
|
46
|
+
tinygrad/runtime/ops_qcom.py,sha256=Dt4hgAd6o13CxsOFRSt7lHY6bCOTLvtQpOr_jx_lYbc,22565
|
47
|
+
tinygrad/runtime/ops_webgpu.py,sha256=-uBAWhVZ7T_9zG2v6PWjkJLpovrClqb_XAPk-l83ryc,4345
|
48
|
+
tinygrad/runtime/autogen/adreno.py,sha256=u7VxIomPAlW3nFUs4gSTe-6ijam_ywkvDM9OuTLF-j8,897915
|
49
|
+
tinygrad/runtime/autogen/amd_gpu.py,sha256=Iasq-zYiv8bvT43dtvPO1W5jaLEQ3d6hP0CoFVhSsak,3977783
|
50
|
+
tinygrad/runtime/autogen/comgr.py,sha256=3pp3XyqEJDBLa9XtGx2-Gc1iJgBbbgIq4pdFEpYXT44,39874
|
51
|
+
tinygrad/runtime/autogen/cuda.py,sha256=N0QyaMvQumr_HZh7fusCHM1d4o4mYti3Wq1MN7JSKr8,243920
|
52
|
+
tinygrad/runtime/autogen/hip.py,sha256=1yUHDCwL3KkD15if2Q1Ud3GbJiR7DxsNorKZTCINw54,245532
|
53
|
+
tinygrad/runtime/autogen/hsa.py,sha256=7Hsrn17HmChyeFOSX_3Fnzl9c0COtq2Z2ExqGu5FNiU,277716
|
54
|
+
tinygrad/runtime/autogen/io_uring.py,sha256=ZIZ2YnQkLr8WIHMieBw9Dv-NZ1ar9TwdP4YBv3gJm28,59786
|
55
|
+
tinygrad/runtime/autogen/kfd.py,sha256=VdhuG4qec0EgM-jJmWcdTS-8WrmywNkcjSX7ibbmvdk,30866
|
56
|
+
tinygrad/runtime/autogen/kgsl.py,sha256=2EgJ5Kst4oRUv81hsV2srgwPvWpY-weaSB4E2lGMAyc,50656
|
57
|
+
tinygrad/runtime/autogen/libc.py,sha256=xKJk2hCzVpauJSc8wCQis5x3SwcXnDli7_HyRUqEGRc,197318
|
58
|
+
tinygrad/runtime/autogen/libpciaccess.py,sha256=zaVmkoUHVTEQcPQwkFpgMCgX4HtX-BL-MGMbi8XtgCI,84194
|
59
|
+
tinygrad/runtime/autogen/llvm.py,sha256=aeVd_ByohxbGRyqXzShPOupI2xtcdk34I6_OIBrMQHg,467606
|
60
|
+
tinygrad/runtime/autogen/nv_gpu.py,sha256=9X2tPdv2E5JmXGZeT8i9jL19YJ4ETTsYwfU_Wn9mTwc,1679326
|
61
|
+
tinygrad/runtime/autogen/nvrtc.py,sha256=19te2-TW5suFy85KnJox3CPOmeeml5YxqIDeL-Bx_m4,23132
|
62
|
+
tinygrad/runtime/autogen/opencl.py,sha256=NL6fa8P3KC_McNZ8g2babdr3b8vrY-bFK0qzNAtL-rE,82656
|
63
|
+
tinygrad/runtime/autogen/qcom_dsp.py,sha256=jx36-zC6reTuWgfbHCrKVjOZcF4Q9fBnq3CuTbxztQk,61848
|
64
|
+
tinygrad/runtime/autogen/vfio.py,sha256=IJV1eeWWllU6b9LAX_IH0bUW5NDzfhPQy_YzXGhD9-8,32431
|
65
|
+
tinygrad/runtime/graph/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
66
|
+
tinygrad/runtime/graph/cuda.py,sha256=vLjT_c93G6ia_1MsbYWP5Uq96Aeko0AOskRkwT5-MUI,4818
|
67
|
+
tinygrad/runtime/graph/hcq.py,sha256=kKu2YnjAZU40XMACSbbcxJSi2xdTg3OYLO2zcPLyAf0,12600
|
68
|
+
tinygrad/runtime/graph/metal.py,sha256=6JN7WJr9w1pIvQGYr0Rsnyg-SA-slK61jKV__KmHnXg,6001
|
69
|
+
tinygrad/runtime/support/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
70
|
+
tinygrad/runtime/support/allocator.py,sha256=INW6TaJxMi4cwMEDtkta2wFH3kJ87wy3SyFFu3_jJ9w,4721
|
71
|
+
tinygrad/runtime/support/compiler_cuda.py,sha256=6cU1OMMW3aOUFNVALpDYWKXh-zFc5q81PSSQhRK9fLw,5471
|
72
|
+
tinygrad/runtime/support/compiler_hip.py,sha256=fbRP82UdG4T-KCRYH_H2hEXlMFeHIJntSnY35ZWE5JY,4398
|
73
|
+
tinygrad/runtime/support/elf.py,sha256=AxWyaAVEe4xdSTiISqIf80oaHRwUZwb5b1V_Q3874s8,3857
|
74
|
+
tinygrad/runtime/support/hcq.py,sha256=JWLRwoGPZCgRBryNb6uq0qCcfkVfMmNnlnHmCzPFxFI,21901
|
75
|
+
tinygrad/runtime/support/llvm.py,sha256=-qI8NRmSx2dBZnG-OQCTWfmy0XySj_T8kHNRHULTyG4,2174
|
76
|
+
tinygrad/runtime/support/am/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
77
|
+
tinygrad/runtime/support/am/amdev.py,sha256=LD80Bc9DHdl99bxrfwn2c43m1NdDD0mjPnSOJkEfkgU,20433
|
78
|
+
tinygrad/runtime/support/am/ip.py,sha256=WnSIZWSG9IvSjVYLJv-VShW_X6vcziynj__lvTOt4yQ,24531
|
79
|
+
tinygrad/shape/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
80
|
+
tinygrad/shape/shapetracker.py,sha256=clrYjN-zEZHSh_biDy3jnHL4Fq9HyWCdH_CRpwLKki0,7612
|
81
|
+
tinygrad/shape/view.py,sha256=7KJwP2lS1YWhMq80Ka6_h_qavyxSEm9jWloVgHYRx-k,18110
|
82
|
+
tinygrad-0.10.1.dist-info/LICENSE,sha256=ABRhUPEILzINYIukgazD-_rPipkUNUwslrb0RxnV6Xc,1058
|
83
|
+
tinygrad-0.10.1.dist-info/METADATA,sha256=JMb7PpBcZqsf0tbmNJzvBqUrH6U1JER_0mz20UMo0LA,11241
|
84
|
+
tinygrad-0.10.1.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
85
|
+
tinygrad-0.10.1.dist-info/top_level.txt,sha256=vDABMCWBFQnx2kn9Azueu88FP-1klQdePoHikQhHymc,9
|
86
|
+
tinygrad-0.10.1.dist-info/RECORD,,
|
tinygrad/engine/lazy.py
DELETED
@@ -1,228 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
from typing import Optional, Any, Tuple, List, get_args
|
3
|
-
from tinygrad.dtype import dtypes, DType, ConstType, to_dtype, ImageDType
|
4
|
-
from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata, SPLIT_REDUCEOP, LAZYCACHE
|
5
|
-
from tinygrad.ops import exec_alu, python_alu
|
6
|
-
from tinygrad.ops import identity_element, MathTrait, resolve, UOp, sint, GroupOp, Ops
|
7
|
-
from tinygrad.shape.shapetracker import ShapeTracker
|
8
|
-
from tinygrad.device import Buffer
|
9
|
-
from weakref import ref, ReferenceType, WeakValueDictionary
|
10
|
-
|
11
|
-
lazycache: WeakValueDictionary[Any, LazyBuffer] = WeakValueDictionary()
|
12
|
-
def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Ops]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
|
13
|
-
base:Optional[LazyBuffer]=None, enable_cache=bool(LAZYCACHE)):
|
14
|
-
if st.size == 0: op, arg, srcs, base = Ops.CONST, 0, (), None
|
15
|
-
dtype = to_dtype(dtype)
|
16
|
-
if op is Ops.CONST: arg, enable_cache = dtypes.as_const(arg, dtype) if not isinstance(arg, UOp) else arg, True
|
17
|
-
|
18
|
-
cache_key = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base))
|
19
|
-
if enable_cache and (rret := lazycache.get(cache_key, None)) is not None: return rret
|
20
|
-
|
21
|
-
ret = LazyBuffer(device, st, dtype, op, arg, srcs, base=base, metadata=_METADATA.get())
|
22
|
-
if enable_cache: lazycache[cache_key] = ret
|
23
|
-
return ret
|
24
|
-
|
25
|
-
view_supported_devices = {"LLVM", "CLANG", "CUDA", "NV", "AMD", "METAL", "QCOM", "DSP", "DISK"}
|
26
|
-
class LazyBuffer(MathTrait):
|
27
|
-
def __init__(self, device:str, st:ShapeTracker, dtype:DType,
|
28
|
-
op:Optional[Ops]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
|
29
|
-
base:Optional[LazyBuffer]=None, metadata:Optional[Metadata]=None):
|
30
|
-
self.device, self.st, self.dtype, self.shape, self.size, self.metadata = device, st, to_dtype(dtype), st.shape, st.size, metadata
|
31
|
-
self._base: Optional[LazyBuffer] = None
|
32
|
-
if base is None:
|
33
|
-
# properties on base
|
34
|
-
self.op, self.arg, self.srcs = op, arg, srcs # this is a UOp, except the src is LazyBuffers and not UOps
|
35
|
-
assert self.op is not Ops.ASSIGN or srcs[0].base.realized is not None, "assign target must be realized"
|
36
|
-
|
37
|
-
if self.op is Ops.BUFFER_VIEW:
|
38
|
-
# some LazyBuffers can be processed with only a view, no AST required
|
39
|
-
self.buffer: Buffer = srcs[0].base.buffer.view(st.size, self.dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize)
|
40
|
-
else:
|
41
|
-
self.buffer = srcs[0].base.buffer if self.op is Ops.ASSIGN else Buffer(device, self.size, self.dtype)
|
42
|
-
self.buffer.ref(1)
|
43
|
-
self.contiguous_child: Optional[Tuple[ReferenceType[LazyBuffer], ShapeTracker]] = None
|
44
|
-
self.forced_realize = False
|
45
|
-
else:
|
46
|
-
# properties on view
|
47
|
-
assert base.base == base, "base must be a base itself"
|
48
|
-
self._base = base
|
49
|
-
|
50
|
-
def __del__(self):
|
51
|
-
if hasattr(self, 'buffer'): self.buffer.ref(-1)
|
52
|
-
|
53
|
-
def __repr__(self) -> str:
|
54
|
-
return f"<LB {self.device} {self.shape} {str(self.dtype)[7:]} {self.st if self.base is not self else (self.op, self.realized)}>"
|
55
|
-
|
56
|
-
@property
|
57
|
-
def realized(self) -> Optional[Buffer]:
|
58
|
-
# NOTE: we check for a lack of srcs instead of an allocated buffer to make unrealized assigns return None here
|
59
|
-
return self.buffer if self._base is None and not hasattr(self, 'srcs') else None
|
60
|
-
|
61
|
-
# NOTE: this has to be a function to prevent self reference
|
62
|
-
@property
|
63
|
-
def base(self) -> LazyBuffer: return self._base if self._base is not None else self
|
64
|
-
|
65
|
-
# same API as multi
|
66
|
-
@property
|
67
|
-
def lbs(self) -> List[LazyBuffer]: return [self]
|
68
|
-
|
69
|
-
@staticmethod
|
70
|
-
def metaop(op, shape:Tuple[sint,...], dtype:DType, device:str, arg=None, src:Tuple[LazyBuffer, ...]=(), enable_cache=False) -> LazyBuffer:
|
71
|
-
assert isinstance(src, tuple)
|
72
|
-
return create_lazybuffer(device, ShapeTracker.from_shape(shape), dtype, op, arg, src, enable_cache=enable_cache)
|
73
|
-
|
74
|
-
def const_like(self, b): return self.const_with_shape(b, self.shape)
|
75
|
-
def const_with_shape(self, val:ConstType, shape:Tuple[sint,...]) -> LazyBuffer:
|
76
|
-
assert isinstance(val, get_args(ConstType)), f"{val=} has {type(val)=}, not a ConstType"
|
77
|
-
return LazyBuffer.metaop(Ops.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(shape)).expand(shape)
|
78
|
-
|
79
|
-
@property
|
80
|
-
def is_realized(self) -> bool: return self.base.realized is not None
|
81
|
-
|
82
|
-
def assign(self, x:LazyBuffer) -> LazyBuffer:
|
83
|
-
assert x.size == self.size, f"assign target must have same size {self.size=} != {x.size=}"
|
84
|
-
assert self.is_realized, f"assign target must be realized {self}"
|
85
|
-
return LazyBuffer.metaop(Ops.ASSIGN, self.shape, self.dtype, self.device, arg=() if self.st.contiguous else (self.st,),
|
86
|
-
src=(self.base, x), enable_cache=True)
|
87
|
-
|
88
|
-
def can_view(self):
|
89
|
-
return (self.st.consecutive and not self.is_unrealized_const() and not isinstance(self.dtype, ImageDType) and
|
90
|
-
self.device.split(":")[0] in view_supported_devices)
|
91
|
-
|
92
|
-
def contiguous(self, allow_buffer_view=True):
|
93
|
-
if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const():
|
94
|
-
ret = self.alu(Ops.BUFFER_VIEW) if allow_buffer_view and self.can_view() else self.alu(Ops.CONTIGUOUS)
|
95
|
-
if (sti := self.st.invert(self.base.shape)) is not None: self.base.contiguous_child = ref(ret), sti
|
96
|
-
return ret
|
97
|
-
self.base.forced_realize = True
|
98
|
-
return self
|
99
|
-
|
100
|
-
def bitcast(self, dtype:DType) -> LazyBuffer: return self.cast(dtype, bitcast=True)
|
101
|
-
def cast(self, dtype:DType, bitcast:bool=False, allow_buffer_view=True) -> LazyBuffer:
|
102
|
-
if self.dtype == dtype: return self
|
103
|
-
if self.device.startswith("DISK") and not bitcast: raise RuntimeError("attempted to cast disk buffer (bitcast only)")
|
104
|
-
if self.is_unrealized_unmasked_const() and not bitcast:
|
105
|
-
return create_lazybuffer(self.device, self.st, dtype, Ops.CONST, dtypes.as_const(self.base.arg, dtype))
|
106
|
-
new_shape = self.shape
|
107
|
-
if bitcast and self.dtype.itemsize != dtype.itemsize:
|
108
|
-
if not self.device.startswith("DISK"): raise RuntimeError("shape changing bitcast only supported on DISK right now")
|
109
|
-
if not all_int(new_shape): raise RuntimeError("shape changing bitcast with symbolic shape isn't supported yet")
|
110
|
-
# https://pytorch.org/docs/stable/generated/torch.Tensor.view.html
|
111
|
-
if not (new_shape[-1]*self.dtype.itemsize) % dtype.itemsize == 0: raise RuntimeError("unsupported size in bitcast")
|
112
|
-
new_shape = new_shape[:-1] + ((new_shape[-1]*self.dtype.itemsize) // dtype.itemsize,)
|
113
|
-
elif getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self is not self.base:
|
114
|
-
# TODO: applying this makes gpt2 slower
|
115
|
-
return self.base.cast(dtype, bitcast)._view(self.st)
|
116
|
-
cast_op: Ops = (Ops.BUFFER_VIEW if self.can_view() and allow_buffer_view else Ops.BITCAST) if bitcast else Ops.CAST
|
117
|
-
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, dtype, (self,))
|
118
|
-
|
119
|
-
def is_unrealized_const(self): return self.base.realized is None and self.base.op is Ops.CONST and not isinstance(self.base.arg, UOp)
|
120
|
-
def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views)
|
121
|
-
|
122
|
-
def _copy(self, device:str) -> LazyBuffer:
|
123
|
-
assert self.st.contiguous and self.size == self.base.size, f"can only copy contig {self} {self.base}"
|
124
|
-
return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, Ops.COPY, self.buffer.nbytes, (self,), enable_cache=False)
|
125
|
-
|
126
|
-
def copy_to_device(self, device:str, force:bool=False, clone:bool=False) -> LazyBuffer:
|
127
|
-
# no COPY
|
128
|
-
if self.device == device and not clone: return self
|
129
|
-
|
130
|
-
# double COPY = one COPY
|
131
|
-
if not force and self.st.contiguous and self.size == self.base.size and not self.base.realized and self.base.op is Ops.COPY:
|
132
|
-
return self.base.srcs[0].copy_to_device(device).reshape(self.st.shape)
|
133
|
-
|
134
|
-
# const doesn't have to be copied (issues with disk tensor)
|
135
|
-
if self.is_unrealized_const():
|
136
|
-
return LazyBuffer.metaop(Ops.CONST, tuple(), self.dtype, device, arg=self.base.arg)._view(self.st)
|
137
|
-
|
138
|
-
# if it's a shrink, do the shrink before the copy with CONTIGUOUS
|
139
|
-
if prod(self.st.shape) < prod(self.base.st.shape): return self.contiguous()._copy(device)
|
140
|
-
|
141
|
-
# copy the base and apply the shapetracker on the new device
|
142
|
-
return self.base._copy(device)._view(self.st)
|
143
|
-
|
144
|
-
def clone(self) -> LazyBuffer: return self.copy_to_device(self.device, clone=True)
|
145
|
-
|
146
|
-
def alu(self, op:Ops, *in_srcs:LazyBuffer) -> LazyBuffer:
|
147
|
-
srcs: List[LazyBuffer] = []
|
148
|
-
for s in (self,)+in_srcs:
|
149
|
-
if s == s.base and s.base.contiguous_child and (root:=s.base.contiguous_child[0]()) is not None:
|
150
|
-
srcs.append(root._view(s.base.contiguous_child[1]))
|
151
|
-
else:
|
152
|
-
srcs.append(s)
|
153
|
-
if not all_same(dts:=[x.dtype.base for x in (srcs[1:] if op is Ops.WHERE else srcs)]):
|
154
|
-
raise AssertionError(f"all dtypes must match {dts} on {op}")
|
155
|
-
assert all_same([x.shape for x in srcs]), f"all shapes must be the same {[x.shape for x in srcs]}"
|
156
|
-
if op is Ops.WHERE: assert srcs[0].dtype == dtypes.bool, "Ops.WHERE must have the first arg be bool"
|
157
|
-
|
158
|
-
out_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else srcs[-1].dtype
|
159
|
-
|
160
|
-
# const folding
|
161
|
-
if op in python_alu and all(s.is_unrealized_unmasked_const() for s in srcs):
|
162
|
-
return self.cast(out_dtype).const_like(exec_alu(op, out_dtype, [s.base.arg for s in srcs]))
|
163
|
-
if op in GroupOp.Binary:
|
164
|
-
x, y = self, in_srcs[0]
|
165
|
-
if op is Ops.ADD:
|
166
|
-
if y.is_unrealized_unmasked_const() and y.base.arg == 0: return x
|
167
|
-
if x.is_unrealized_unmasked_const() and x.base.arg == 0: return y
|
168
|
-
if op is Ops.MUL:
|
169
|
-
if x.is_unrealized_unmasked_const() and (val := x.base.arg) in (1, 0): return y if val == 1 else y.const_like(0)
|
170
|
-
if y.is_unrealized_unmasked_const() and (val := y.base.arg) in (1, 0): return x if val == 1 else x.const_like(0)
|
171
|
-
if op is Ops.IDIV and y.is_unrealized_unmasked_const() and y.base.arg == 1: return x
|
172
|
-
|
173
|
-
return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, None, tuple(srcs))
|
174
|
-
|
175
|
-
# *** reduce ops ***
|
176
|
-
|
177
|
-
def _reduce_op(self, op:Ops, axis:Tuple[int, ...]) -> LazyBuffer:
|
178
|
-
assert all(0 <= x < len(self.shape) for x in axis), f"axis args {axis} out of range for shape {self.shape}"
|
179
|
-
axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
|
180
|
-
if len(axis) == 0: return self
|
181
|
-
return create_lazybuffer(self.device, ShapeTracker.from_shape(self.st.reduce(axis)), self.dtype, Ops.REDUCE_AXIS, (op, axis), (self,))
|
182
|
-
|
183
|
-
def r(self, op:Ops, axis:Tuple[int, ...]) -> LazyBuffer:
|
184
|
-
new_shape = self.st.reduce(axis)
|
185
|
-
# TODO: this logic should move to the scheduler
|
186
|
-
if 0 in self.shape and 0 not in new_shape: return self.const_with_shape(identity_element(op, self.dtype), new_shape)
|
187
|
-
|
188
|
-
# const folding
|
189
|
-
# TODO: fold this for symbolic?
|
190
|
-
if self.is_unrealized_unmasked_const() and all_int(self.shape):
|
191
|
-
if op is Ops.ADD: return self.const_with_shape(self.base.arg * prod(self.shape[i] for i in axis), new_shape)
|
192
|
-
if op is Ops.MUL: return self.const_with_shape(self.base.arg ** prod(self.shape[i] for i in axis), new_shape)
|
193
|
-
if op is Ops.MAX: return self.const_with_shape(self.base.arg, new_shape)
|
194
|
-
|
195
|
-
# TODO: can we split symbolic shape if the reduce axis is not symbolic?
|
196
|
-
if not SPLIT_REDUCEOP or not all_int(self.shape) or (0 in self.shape) or \
|
197
|
-
prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
|
198
|
-
return self._reduce_op(op, axis)
|
199
|
-
|
200
|
-
# if there are few globals, make some reduces into globals by splitting into two kernels
|
201
|
-
# cap output buffer to 2**22: heuristic number of global outputs to achieve max occupancy with enough locals+upcasts for gemm
|
202
|
-
# ~2**10 should be enough if GROUP is used
|
203
|
-
# 256 split maximum should be "negligible reduce" for low prod(new_shape), 8 split minimum.
|
204
|
-
# split is moved to the end to provide maximum locality for the second phase reduce.
|
205
|
-
self_real_strides = self.st.real_strides(ignore_valid=True)
|
206
|
-
split_candidates = [(i, x) for i in axis for x in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(new_shape)),8-1,-1)
|
207
|
-
if self.shape[i] % x == 0 and self_real_strides[i] != 0]
|
208
|
-
if not split_candidates: return self._reduce_op(op, axis)
|
209
|
-
dim_to_split, divisor = split_candidates[0]
|
210
|
-
splitted_shape = self.shape[:dim_to_split] + (divisor,) + (self.shape[dim_to_split]//divisor,) + self.shape[dim_to_split+1:]
|
211
|
-
splitted = self.reshape(splitted_shape).permute(tuple([x for x in range(len(splitted_shape)) if x != dim_to_split]+[dim_to_split]))
|
212
|
-
if DEBUG >= 3: print(f"split {divisor}: {self.shape} -> {splitted.shape} -> {new_shape}")
|
213
|
-
return splitted._reduce_op(op, axis)._reduce_op(op, (len(new_shape),)).reshape(new_shape) # reduce original axes, then split
|
214
|
-
|
215
|
-
# *** movement ops ***
|
216
|
-
|
217
|
-
def _view(self, new_st:ShapeTracker) -> LazyBuffer:
|
218
|
-
if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)):
|
219
|
-
return self.const_with_shape(0, new_st.shape)
|
220
|
-
if new_st.contiguous and self.base.shape == new_st.shape: return self.base
|
221
|
-
return create_lazybuffer(self.device, new_st, self.dtype, base=self.base)
|
222
|
-
|
223
|
-
def reshape(self, arg:Tuple[sint, ...]): return self._view(self.st.reshape(arg))
|
224
|
-
def pad(self, arg:Tuple[Tuple[sint, sint], ...]): return self._view(self.st.pad(arg))
|
225
|
-
def expand(self, arg:Tuple[sint, ...]): return self._view(self.st.expand(arg))
|
226
|
-
def permute(self, arg:Tuple[int, ...]): return self._view(self.st.permute(arg))
|
227
|
-
def shrink(self, arg:Tuple[Tuple[sint, sint], ...]): return self._view(self.st.shrink(arg))
|
228
|
-
def stride(self, arg:Tuple[int, ...]): return self._view(self.st.stride(arg))
|
tinygrad/function.py
DELETED
@@ -1,212 +0,0 @@
|
|
1
|
-
"""This is where the forwards and backwards passes live."""
|
2
|
-
import math
|
3
|
-
from typing import Tuple, Optional
|
4
|
-
from tinygrad.helpers import argsort
|
5
|
-
from tinygrad.dtype import dtypes, DType, sum_acc_dtype
|
6
|
-
from tinygrad.ops import Ops, resolve, sint
|
7
|
-
from tinygrad.tensor import Function
|
8
|
-
from tinygrad.engine.lazy import LazyBuffer
|
9
|
-
|
10
|
-
class Contiguous(Function):
|
11
|
-
def forward(self, x:LazyBuffer) -> LazyBuffer: return x.contiguous()
|
12
|
-
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output
|
13
|
-
|
14
|
-
class ContiguousBackward(Function):
|
15
|
-
def forward(self, x:LazyBuffer) -> LazyBuffer: return x
|
16
|
-
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.contiguous()
|
17
|
-
|
18
|
-
class Cast(Function):
|
19
|
-
def forward(self, x:LazyBuffer, dtype:DType, bitcast:bool=False) -> LazyBuffer:
|
20
|
-
self.input_dtype, self.bitcast = x.dtype, bitcast
|
21
|
-
return x.bitcast(dtype) if self.bitcast else x.cast(dtype)
|
22
|
-
|
23
|
-
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
24
|
-
if self.bitcast: raise RuntimeError("bitcast cannot backward")
|
25
|
-
return grad_output.cast(self.input_dtype)
|
26
|
-
|
27
|
-
# ************* unary ops *************
|
28
|
-
|
29
|
-
class Reciprocal(Function):
|
30
|
-
def forward(self, x:LazyBuffer) -> LazyBuffer:
|
31
|
-
self.ret = x.reciprocal()
|
32
|
-
return self.ret
|
33
|
-
|
34
|
-
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return -grad_output * self.ret * self.ret
|
35
|
-
|
36
|
-
class Sin(Function):
|
37
|
-
def forward(self, x:LazyBuffer) -> LazyBuffer:
|
38
|
-
self.x = x
|
39
|
-
return x.sin()
|
40
|
-
|
41
|
-
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return (math.pi/2 - self.x).sin() * grad_output
|
42
|
-
|
43
|
-
class Relu(Function):
|
44
|
-
def forward(self, x:LazyBuffer) -> LazyBuffer:
|
45
|
-
self.ret = x.maximum(0)
|
46
|
-
return self.ret
|
47
|
-
|
48
|
-
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return self.ret.gt(0).cast(grad_output.dtype) * grad_output
|
49
|
-
|
50
|
-
class Log(Function):
|
51
|
-
def forward(self, x:LazyBuffer) -> LazyBuffer:
|
52
|
-
self.x = x
|
53
|
-
return x.log2() * math.log(2)
|
54
|
-
|
55
|
-
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output / self.x
|
56
|
-
|
57
|
-
class Exp(Function):
|
58
|
-
def forward(self, x:LazyBuffer) -> LazyBuffer:
|
59
|
-
self.ret = (x * (1/math.log(2))).exp2()
|
60
|
-
return self.ret
|
61
|
-
|
62
|
-
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return self.ret * grad_output
|
63
|
-
|
64
|
-
class Sqrt(Function):
|
65
|
-
def forward(self, x:LazyBuffer) -> LazyBuffer:
|
66
|
-
self.ret = x.sqrt()
|
67
|
-
return self.ret
|
68
|
-
|
69
|
-
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output / (self.ret*2)
|
70
|
-
|
71
|
-
# NOTE: the implicit derivative of sigmoid is not stable
|
72
|
-
# https://towardsdatascience.com/derivative-of-the-sigmoid-function-536880cf918e
|
73
|
-
# TODO: have the backend automatically find this
|
74
|
-
class Sigmoid(Function):
|
75
|
-
def forward(self, x:LazyBuffer) -> LazyBuffer:
|
76
|
-
self.ret = (1 + (x * (-1/math.log(2))).exp2()).reciprocal()
|
77
|
-
return self.ret
|
78
|
-
|
79
|
-
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
80
|
-
return (self.ret * (1 - self.ret)) * grad_output
|
81
|
-
|
82
|
-
class Sign(Function):
|
83
|
-
def forward(self, x:LazyBuffer) -> LazyBuffer: return x.ne(0).where(x.lt(0).where(x.const_like(-1), x.const_like(1)), x.const_like(0))
|
84
|
-
# backward always return 0 to match torch
|
85
|
-
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.const_like(0)
|
86
|
-
|
87
|
-
# ************* binary ops *************
|
88
|
-
|
89
|
-
class Less(Function):
|
90
|
-
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.lt(y)
|
91
|
-
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None
|
92
|
-
|
93
|
-
class Neq(Function):
|
94
|
-
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.ne(y)
|
95
|
-
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None
|
96
|
-
|
97
|
-
class Xor(Function):
|
98
|
-
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x^y
|
99
|
-
|
100
|
-
class BitwiseAnd(Function):
|
101
|
-
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x&y
|
102
|
-
|
103
|
-
class BitwiseOr(Function):
|
104
|
-
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x|y
|
105
|
-
|
106
|
-
class Threefry(Function):
|
107
|
-
def forward(self, x:LazyBuffer, seed:LazyBuffer) -> LazyBuffer: return x.threefry(seed)
|
108
|
-
|
109
|
-
class Add(Function):
|
110
|
-
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x+y
|
111
|
-
|
112
|
-
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
|
113
|
-
return grad_output if self.needs_input_grad[0] else None, \
|
114
|
-
grad_output if self.needs_input_grad[1] else None
|
115
|
-
|
116
|
-
class Mul(Function):
|
117
|
-
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
|
118
|
-
self.x, self.y = x, y
|
119
|
-
return x * y
|
120
|
-
|
121
|
-
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
|
122
|
-
return (self.y * grad_output) if self.needs_input_grad[0] else None, \
|
123
|
-
(self.x * grad_output) if self.needs_input_grad[1] else None
|
124
|
-
|
125
|
-
class IDiv(Function):
|
126
|
-
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x // y
|
127
|
-
|
128
|
-
# ************* ternary ops *************
|
129
|
-
|
130
|
-
class Where(Function):
|
131
|
-
def forward(self, x:LazyBuffer, y:LazyBuffer, z:LazyBuffer) -> LazyBuffer:
|
132
|
-
self.x = x
|
133
|
-
return self.x.where(y, z)
|
134
|
-
|
135
|
-
def backward(self, grad_output:LazyBuffer) -> Tuple[None, Optional[LazyBuffer], Optional[LazyBuffer]]:
|
136
|
-
return None, \
|
137
|
-
self.x.where(grad_output, grad_output.const_like(0)) if self.needs_input_grad[1] else None, \
|
138
|
-
self.x.where(grad_output.const_like(0), grad_output) if self.needs_input_grad[2] else None
|
139
|
-
|
140
|
-
# ************* reduce ops *************
|
141
|
-
|
142
|
-
class Sum(Function):
|
143
|
-
def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
|
144
|
-
self.input_shape = x.shape
|
145
|
-
return x.r(Ops.ADD, axis)
|
146
|
-
|
147
|
-
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.expand(self.input_shape)
|
148
|
-
|
149
|
-
class Prod(Function):
|
150
|
-
def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
|
151
|
-
self.x, self.ret = x, x.r(Ops.MUL, axis)
|
152
|
-
return self.ret
|
153
|
-
|
154
|
-
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
155
|
-
return (grad_output * self.ret).expand(self.x.shape) / self.x
|
156
|
-
|
157
|
-
class Max(Function):
|
158
|
-
def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
|
159
|
-
self.x, self.ret, self.axis = x, x.r(Ops.MAX, axis), axis
|
160
|
-
return self.ret
|
161
|
-
|
162
|
-
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
163
|
-
# 1s in locations where the max was chosen (can be two locations)
|
164
|
-
max_is_1s = self.x.ne(self.ret.expand(self.x.shape)).ne(self.x.const_like(1).cast(dtypes.bool)).cast(grad_output.dtype)
|
165
|
-
div = max_is_1s.r(Ops.ADD, self.axis).expand(self.x.shape)
|
166
|
-
return (max_is_1s/div) * grad_output.expand(self.x.shape)
|
167
|
-
|
168
|
-
# ************* movement ops *************
|
169
|
-
|
170
|
-
# NOTE: this is sum in reverse
|
171
|
-
class Expand(Function):
|
172
|
-
def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
|
173
|
-
self.expanded_axis = tuple(i for i, (si, so) in enumerate(zip(x.shape, shape)) if resolve(si != so))
|
174
|
-
return x.expand(shape)
|
175
|
-
|
176
|
-
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
177
|
-
return grad_output.cast(sum_acc_dtype(grad_output.dtype)).r(Ops.ADD, self.expanded_axis).cast(grad_output.dtype)
|
178
|
-
|
179
|
-
class Reshape(Function):
|
180
|
-
def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
|
181
|
-
self.input_shape = x.shape
|
182
|
-
return x.reshape(shape)
|
183
|
-
|
184
|
-
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.reshape(self.input_shape)
|
185
|
-
|
186
|
-
class Permute(Function):
|
187
|
-
def forward(self, x:LazyBuffer, order:Tuple[int, ...]) -> LazyBuffer:
|
188
|
-
self.input_order = order
|
189
|
-
return x.permute(order)
|
190
|
-
|
191
|
-
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.permute(argsort(self.input_order))
|
192
|
-
|
193
|
-
class Pad(Function):
|
194
|
-
def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
|
195
|
-
self.narg = tuple([(p[0], s+p[0]) for s,p in zip(x.shape, arg)])
|
196
|
-
return x.pad(arg)
|
197
|
-
|
198
|
-
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.shrink(self.narg)
|
199
|
-
|
200
|
-
class Shrink(Function):
|
201
|
-
def forward(self, x:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer:
|
202
|
-
self.narg = tuple([(p[0], s-p[1]) for s,p in zip(x.shape, arg)])
|
203
|
-
return x.shrink(arg)
|
204
|
-
|
205
|
-
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.pad(self.narg)
|
206
|
-
|
207
|
-
class Flip(Function):
|
208
|
-
def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
|
209
|
-
self.arg = tuple([-1 if i in axis else 1 for i in range(len(x.shape))])
|
210
|
-
return x.stride(self.arg)
|
211
|
-
|
212
|
-
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.stride(self.arg)
|