tinygrad 0.10.0-py3-none-any.whl → 0.10.1-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (72)
  1. tinygrad/codegen/kernel.py +114 -172
  2. tinygrad/codegen/linearize.py +211 -81
  3. tinygrad/codegen/lowerer.py +30 -35
  4. tinygrad/codegen/{uopgraph.py → rewriter.py} +69 -59
  5. tinygrad/codegen/transcendental.py +12 -13
  6. tinygrad/device.py +170 -47
  7. tinygrad/dtype.py +28 -26
  8. tinygrad/engine/jit.py +80 -63
  9. tinygrad/engine/memory.py +4 -5
  10. tinygrad/engine/multi.py +162 -0
  11. tinygrad/engine/realize.py +58 -107
  12. tinygrad/engine/schedule.py +381 -314
  13. tinygrad/engine/search.py +40 -44
  14. tinygrad/gradient.py +70 -0
  15. tinygrad/helpers.py +77 -58
  16. tinygrad/nn/__init__.py +30 -32
  17. tinygrad/nn/datasets.py +1 -2
  18. tinygrad/nn/optim.py +22 -26
  19. tinygrad/nn/state.py +89 -64
  20. tinygrad/ops.py +562 -446
  21. tinygrad/renderer/__init__.py +79 -36
  22. tinygrad/renderer/cstyle.py +70 -84
  23. tinygrad/renderer/llvmir.py +32 -20
  24. tinygrad/renderer/ptx.py +79 -99
  25. tinygrad/renderer/wgsl.py +87 -0
  26. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  27. tinygrad/runtime/autogen/comgr.py +2 -0
  28. tinygrad/runtime/autogen/kfd.py +4 -3
  29. tinygrad/runtime/autogen/kgsl.py +1 -1
  30. tinygrad/runtime/autogen/libpciaccess.py +2023 -0
  31. tinygrad/runtime/autogen/llvm.py +11379 -0
  32. tinygrad/runtime/autogen/vfio.py +891 -0
  33. tinygrad/runtime/graph/cuda.py +8 -9
  34. tinygrad/runtime/graph/hcq.py +84 -79
  35. tinygrad/runtime/graph/metal.py +19 -21
  36. tinygrad/runtime/ops_amd.py +488 -327
  37. tinygrad/runtime/ops_clang.py +15 -28
  38. tinygrad/runtime/ops_cloud.py +34 -34
  39. tinygrad/runtime/ops_cuda.py +30 -27
  40. tinygrad/runtime/ops_disk.py +62 -63
  41. tinygrad/runtime/ops_dsp.py +129 -38
  42. tinygrad/runtime/ops_gpu.py +30 -30
  43. tinygrad/runtime/ops_hip.py +29 -31
  44. tinygrad/runtime/ops_llvm.py +45 -40
  45. tinygrad/runtime/ops_metal.py +93 -73
  46. tinygrad/runtime/ops_npy.py +2 -2
  47. tinygrad/runtime/ops_nv.py +232 -270
  48. tinygrad/runtime/ops_python.py +51 -46
  49. tinygrad/runtime/ops_qcom.py +129 -157
  50. tinygrad/runtime/ops_webgpu.py +63 -0
  51. tinygrad/runtime/support/allocator.py +94 -0
  52. tinygrad/runtime/support/am/__init__.py +0 -0
  53. tinygrad/runtime/support/am/amdev.py +384 -0
  54. tinygrad/runtime/support/am/ip.py +463 -0
  55. tinygrad/runtime/support/compiler_cuda.py +4 -2
  56. tinygrad/runtime/support/elf.py +26 -4
  57. tinygrad/runtime/support/hcq.py +254 -324
  58. tinygrad/runtime/support/llvm.py +32 -0
  59. tinygrad/shape/shapetracker.py +84 -53
  60. tinygrad/shape/view.py +103 -138
  61. tinygrad/spec.py +154 -0
  62. tinygrad/tensor.py +744 -496
  63. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/METADATA +32 -21
  64. tinygrad-0.10.1.dist-info/RECORD +86 -0
  65. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/WHEEL +1 -1
  66. tinygrad/engine/lazy.py +0 -228
  67. tinygrad/function.py +0 -212
  68. tinygrad/multi.py +0 -177
  69. tinygrad/runtime/graph/clang.py +0 -39
  70. tinygrad-0.10.0.dist-info/RECORD +0 -77
  71. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/LICENSE +0 -0
  72. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/top_level.txt +0 -0
{tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/METADATA
@@ -1,6 +1,6 @@
- Metadata-Version: 2.1
+ Metadata-Version: 2.2
  Name: tinygrad
- Version: 0.10.0
+ Version: 0.10.1
  Summary: You like pytorch? You like micrograd? You love tinygrad! <3
  Author: George Hotz
  License: MIT
@@ -11,26 +11,19 @@ Description-Content-Type: text/markdown
  License-File: LICENSE
  Provides-Extra: arm
  Requires-Dist: unicorn; extra == "arm"
- Provides-Extra: docs
- Requires-Dist: mkdocs; extra == "docs"
- Requires-Dist: mkdocs-material; extra == "docs"
- Requires-Dist: mkdocstrings[python]; extra == "docs"
- Requires-Dist: markdown-callouts; extra == "docs"
- Requires-Dist: markdown-exec[ansi]; extra == "docs"
- Requires-Dist: black; extra == "docs"
- Requires-Dist: numpy; extra == "docs"
+ Provides-Extra: triton
+ Requires-Dist: triton-nightly>=2.1.0.dev20231014192330; extra == "triton"
  Provides-Extra: linting
  Requires-Dist: pylint; extra == "linting"
- Requires-Dist: mypy==1.11.2; extra == "linting"
+ Requires-Dist: mypy==1.13.0; extra == "linting"
  Requires-Dist: typing-extensions; extra == "linting"
  Requires-Dist: pre-commit; extra == "linting"
  Requires-Dist: ruff; extra == "linting"
  Requires-Dist: types-tqdm; extra == "linting"
- Provides-Extra: llvm
- Requires-Dist: llvmlite; extra == "llvm"
  Provides-Extra: testing
  Requires-Dist: numpy; extra == "testing"
  Requires-Dist: torch; extra == "testing"
+ Requires-Dist: jax; extra == "testing"
  Requires-Dist: pillow; extra == "testing"
  Requires-Dist: pytest; extra == "testing"
  Requires-Dist: pytest-xdist; extra == "testing"
@@ -50,11 +43,28 @@ Requires-Dist: hypothesis; extra == "testing"
  Requires-Dist: nibabel; extra == "testing"
  Requires-Dist: bottle; extra == "testing"
  Requires-Dist: ggml-python; extra == "testing"
- Provides-Extra: testing_tf
+ Requires-Dist: capstone; extra == "testing"
+ Provides-Extra: webgpu
+ Requires-Dist: wgpu; extra == "webgpu"
+ Provides-Extra: docs
+ Requires-Dist: mkdocs; extra == "docs"
+ Requires-Dist: mkdocs-material; extra == "docs"
+ Requires-Dist: mkdocstrings[python]; extra == "docs"
+ Requires-Dist: markdown-callouts; extra == "docs"
+ Requires-Dist: markdown-exec[ansi]; extra == "docs"
+ Requires-Dist: black; extra == "docs"
+ Requires-Dist: numpy; extra == "docs"
+ Provides-Extra: testing-tf
  Requires-Dist: tensorflow==2.15.1; extra == "testing-tf"
- Requires-Dist: tensorflow-addons; extra == "testing-tf"
- Provides-Extra: triton
- Requires-Dist: triton-nightly>=2.1.0.dev20231014192330; extra == "triton"
+ Requires-Dist: tensorflow_addons; extra == "testing-tf"
+ Dynamic: author
+ Dynamic: classifier
+ Dynamic: description
+ Dynamic: description-content-type
+ Dynamic: license
+ Dynamic: provides-extra
+ Dynamic: requires-python
+ Dynamic: summary

  <div align="center">

@@ -146,6 +156,7 @@ tinygrad already supports numerous accelerators, including:
  - [x] [AMD](tinygrad/runtime/ops_amd.py)
  - [x] [NV](tinygrad/runtime/ops_nv.py)
  - [x] [QCOM](tinygrad/runtime/ops_qcom.py)
+ - [x] [WEBGPU](tinygrad/runtime/ops_webgpu.py)

  And it is easy to add more! Your accelerator of choice only needs to support a total of ~25 low level ops.

@@ -183,8 +194,8 @@ y = Tensor([[2.0,0,-2.0]], requires_grad=True)
  z = y.matmul(x).sum()
  z.backward()

- print(x.grad.numpy()) # dz/dx
- print(y.grad.numpy()) # dz/dy
+ print(x.grad.tolist()) # dz/dx
+ print(y.grad.tolist()) # dz/dy
  ```

  The same thing but in PyTorch:
@@ -196,8 +207,8 @@ y = torch.tensor([[2.0,0,-2.0]], requires_grad=True)
  z = y.matmul(x).sum()
  z.backward()

- print(x.grad.numpy()) # dz/dx
- print(y.grad.numpy()) # dz/dy
+ print(x.grad.tolist()) # dz/dx
+ print(y.grad.tolist()) # dz/dy
  ```

  ## Contributing
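The README hunks above switch the gradient printouts from `.numpy()` to `.tolist()`. For reference, a complete version of the updated tinygrad snippet might look like the following; the import and the definition of `x` are not visible in the hunk (only `y` appears as hunk-header context), so they are assumed here:

```python
from tinygrad import Tensor

x = Tensor.eye(3, requires_grad=True)             # assumed: not shown in the hunk above
y = Tensor([[2.0, 0, -2.0]], requires_grad=True)  # from the hunk header context
z = y.matmul(x).sum()
z.backward()

print(x.grad.tolist())  # dz/dx
print(y.grad.tolist())  # dz/dy
```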
tinygrad-0.10.1.dist-info/RECORD
@@ -0,0 +1,86 @@
+ tinygrad/__init__.py,sha256=2Jhg7NSWlegCi4OAfGW0iHRVHeqMx09f7446rwAmc60,587
+ tinygrad/device.py,sha256=mUrxoZJfqBJepDeqbmS2Y8UbOB_UPERqP9zBN7sCBk8,18968
+ tinygrad/dtype.py,sha256=010zGuqXwUoyrhe23nJQ5oHmJt_6CRL7Hcl0kWjvKbo,9843
+ tinygrad/gradient.py,sha256=hVyRMnwzjtWwKfF0NMGQovE2I7_2GdKFyR9pWJn4eE4,4280
+ tinygrad/helpers.py,sha256=DIofGx-mg-umNKjhNpINcrxZExUwo6JIKHgmFuKNLUM,19203
+ tinygrad/ops.py,sha256=T1gmR1ywWPnGEcw-MPAW5SGEvNS_789PtnFrR8rz5H4,70614
+ tinygrad/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tinygrad/spec.py,sha256=Wi8PdgeIHzv6WCTwKT9zGP5_53Bt_QtmvGTjIzTuwi0,8953
+ tinygrad/tensor.py,sha256=-DVucOwq8HmM7IHqW7ZNABy7-vYZI-YfF0Vzdo0rs70,181559
+ tinygrad/codegen/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tinygrad/codegen/kernel.py,sha256=4TcuMvvGODd0zFBuhK9KzvQhwzknRgeCL_n9WrOpljQ,42432
+ tinygrad/codegen/linearize.py,sha256=X9aOMhvimDTvMq2Vnj3zfGNIRhJnAG5mLcq6EeUfvvU,10382
+ tinygrad/codegen/lowerer.py,sha256=DqFg53oALqWG7RygOjXEioyGk52BnGbCzvQSqiOIKGw,7527
+ tinygrad/codegen/rewriter.py,sha256=T3tOn-T1IFqiP90GCqfmkr0V5i0knr2tjQfKPVr6zzM,29948
+ tinygrad/codegen/transcendental.py,sha256=0qRaEtIoJKDfjPqvQWexShZW3F__wtzfjhU__BqiMD8,13112
+ tinygrad/engine/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tinygrad/engine/jit.py,sha256=6D8pCpRPZNANwFEHAsAMUgUaEwaCf4mB_1oem6oZZ7o,16630
+ tinygrad/engine/memory.py,sha256=UyiNYIoUjtcJ1SX6ApoCnbrSKZpbBbhfwr13TIaiqEM,3231
+ tinygrad/engine/multi.py,sha256=FnyCvwFTdBXiydXB0dD5Y63KnerCurZS_o8WOU5fHFM,10371
+ tinygrad/engine/realize.py,sha256=vOQ33sTn_P2oT0GRFU5tKYLfiRmxrIIdVs27v0_a40g,9690
+ tinygrad/engine/schedule.py,sha256=9JywWlmVo912JmbIF9r5ajr5-Sesqw1EhE-CYnk5LHI,27407
+ tinygrad/engine/search.py,sha256=VOsUOPaI7zWtDgL2SJDiKSPaTs3MhjgZOJBZ5_ISeiY,12093
+ tinygrad/nn/__init__.py,sha256=BAxMz-g7v-v1A32KaBzmGiaEnvQ_dU9d3KoPdYYwLDQ,15156
+ tinygrad/nn/datasets.py,sha256=wcT0Qrlpw_RzM7uBy8uphzKAjGT-ZE48fiP-0g3WvI0,1042
+ tinygrad/nn/optim.py,sha256=qfdYKi_ssX5O_DU6h8GJ0WCzBzAZLyyS3p_946PJNsQ,6816
+ tinygrad/nn/state.py,sha256=zXFMwAw7sf35C9x3iAxYiPsmhk7_S6qPcX3wXxqp6Bw,16030
+ tinygrad/renderer/__init__.py,sha256=KvZ3y7MnqKDHOKoTcxghT95FhTBJuzs3IFEBM077jw8,6989
+ tinygrad/renderer/cstyle.py,sha256=TrmxtioR7Lgw_8CK-Ay9EA0HjS1TXRf-RaMsCsuHszw,29852
+ tinygrad/renderer/llvmir.py,sha256=iHFjE9-GtH2sGD_feCGe7aGECPHJ5qGOyuuyCDxSEXU,8529
+ tinygrad/renderer/ptx.py,sha256=is4PMkdbDDgKsTy67DtkugatdhbtVMijKE9u1f6-0ag,14899
+ tinygrad/renderer/wgsl.py,sha256=3AGZp2jvymxLZSJHjzD7zlwxvds36-lSu2ZkntrW_ww,6697
+ tinygrad/runtime/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tinygrad/runtime/ops_amd.py,sha256=zRictmcgUPw6d8iykz6QWqCmiPirOh-hTACT-GYxso8,37282
+ tinygrad/runtime/ops_clang.py,sha256=InJPVVhY10bxRXOaDQRUBlWafSZRdo72Y0VVW9O6dBw,1310
+ tinygrad/runtime/ops_cloud.py,sha256=1umLappzYKg0ECC9C_EN1lzXquAQVqY7wj-mS0vTcSk,10362
+ tinygrad/runtime/ops_cuda.py,sha256=26ko2bTLT_jDmU2i8_xiL4BomD9krzvlz6wY5FE-e5c,7138
+ tinygrad/runtime/ops_disk.py,sha256=3eTpJSVQQH32KBa7tRLBm2pPSkCaIC1adPc748UkUeg,6619
+ tinygrad/runtime/ops_dsp.py,sha256=JCSehSDLzaa8Nhd6Xtas0d5vOCgwQGY4U3awHOLThEI,16900
+ tinygrad/runtime/ops_gpu.py,sha256=VrY9iM5i44OQwTJE8sgUavCtr1TOpE96I5-JmqUFz-E,8972
+ tinygrad/runtime/ops_hip.py,sha256=MbR4depxgHcaGpOJGvUCiLq5tdbpaDiqs-Xj41rY2xQ,3730
+ tinygrad/runtime/ops_llvm.py,sha256=yjW5wVMJdmayeeTQwnQ60mtOWiiYEqvtb9Rt5cDj2rU,3239
+ tinygrad/runtime/ops_metal.py,sha256=rz3fiaiB1by_RuP7VFVBgGA1mye43_-6PIfMfeQ3YmU,13229
+ tinygrad/runtime/ops_npy.py,sha256=8VNf1S5M_MRk9d3GxSsTPbfEz7I_aOwl7QMZ1mUG3As,370
+ tinygrad/runtime/ops_nv.py,sha256=8NvxtEc89dbJ5xgVhGs9_zeGMztp9HRZHLwPqylDsdQ,34541
+ tinygrad/runtime/ops_python.py,sha256=5CZek-7fMhq6w07lGDM0pXltZzzMMRtuU2x5PeWMyWY,11681
+ tinygrad/runtime/ops_qcom.py,sha256=Dt4hgAd6o13CxsOFRSt7lHY6bCOTLvtQpOr_jx_lYbc,22565
+ tinygrad/runtime/ops_webgpu.py,sha256=-uBAWhVZ7T_9zG2v6PWjkJLpovrClqb_XAPk-l83ryc,4345
+ tinygrad/runtime/autogen/adreno.py,sha256=u7VxIomPAlW3nFUs4gSTe-6ijam_ywkvDM9OuTLF-j8,897915
+ tinygrad/runtime/autogen/amd_gpu.py,sha256=Iasq-zYiv8bvT43dtvPO1W5jaLEQ3d6hP0CoFVhSsak,3977783
+ tinygrad/runtime/autogen/comgr.py,sha256=3pp3XyqEJDBLa9XtGx2-Gc1iJgBbbgIq4pdFEpYXT44,39874
+ tinygrad/runtime/autogen/cuda.py,sha256=N0QyaMvQumr_HZh7fusCHM1d4o4mYti3Wq1MN7JSKr8,243920
+ tinygrad/runtime/autogen/hip.py,sha256=1yUHDCwL3KkD15if2Q1Ud3GbJiR7DxsNorKZTCINw54,245532
+ tinygrad/runtime/autogen/hsa.py,sha256=7Hsrn17HmChyeFOSX_3Fnzl9c0COtq2Z2ExqGu5FNiU,277716
+ tinygrad/runtime/autogen/io_uring.py,sha256=ZIZ2YnQkLr8WIHMieBw9Dv-NZ1ar9TwdP4YBv3gJm28,59786
+ tinygrad/runtime/autogen/kfd.py,sha256=VdhuG4qec0EgM-jJmWcdTS-8WrmywNkcjSX7ibbmvdk,30866
+ tinygrad/runtime/autogen/kgsl.py,sha256=2EgJ5Kst4oRUv81hsV2srgwPvWpY-weaSB4E2lGMAyc,50656
+ tinygrad/runtime/autogen/libc.py,sha256=xKJk2hCzVpauJSc8wCQis5x3SwcXnDli7_HyRUqEGRc,197318
+ tinygrad/runtime/autogen/libpciaccess.py,sha256=zaVmkoUHVTEQcPQwkFpgMCgX4HtX-BL-MGMbi8XtgCI,84194
+ tinygrad/runtime/autogen/llvm.py,sha256=aeVd_ByohxbGRyqXzShPOupI2xtcdk34I6_OIBrMQHg,467606
+ tinygrad/runtime/autogen/nv_gpu.py,sha256=9X2tPdv2E5JmXGZeT8i9jL19YJ4ETTsYwfU_Wn9mTwc,1679326
+ tinygrad/runtime/autogen/nvrtc.py,sha256=19te2-TW5suFy85KnJox3CPOmeeml5YxqIDeL-Bx_m4,23132
+ tinygrad/runtime/autogen/opencl.py,sha256=NL6fa8P3KC_McNZ8g2babdr3b8vrY-bFK0qzNAtL-rE,82656
+ tinygrad/runtime/autogen/qcom_dsp.py,sha256=jx36-zC6reTuWgfbHCrKVjOZcF4Q9fBnq3CuTbxztQk,61848
+ tinygrad/runtime/autogen/vfio.py,sha256=IJV1eeWWllU6b9LAX_IH0bUW5NDzfhPQy_YzXGhD9-8,32431
+ tinygrad/runtime/graph/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tinygrad/runtime/graph/cuda.py,sha256=vLjT_c93G6ia_1MsbYWP5Uq96Aeko0AOskRkwT5-MUI,4818
+ tinygrad/runtime/graph/hcq.py,sha256=kKu2YnjAZU40XMACSbbcxJSi2xdTg3OYLO2zcPLyAf0,12600
+ tinygrad/runtime/graph/metal.py,sha256=6JN7WJr9w1pIvQGYr0Rsnyg-SA-slK61jKV__KmHnXg,6001
+ tinygrad/runtime/support/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tinygrad/runtime/support/allocator.py,sha256=INW6TaJxMi4cwMEDtkta2wFH3kJ87wy3SyFFu3_jJ9w,4721
+ tinygrad/runtime/support/compiler_cuda.py,sha256=6cU1OMMW3aOUFNVALpDYWKXh-zFc5q81PSSQhRK9fLw,5471
+ tinygrad/runtime/support/compiler_hip.py,sha256=fbRP82UdG4T-KCRYH_H2hEXlMFeHIJntSnY35ZWE5JY,4398
+ tinygrad/runtime/support/elf.py,sha256=AxWyaAVEe4xdSTiISqIf80oaHRwUZwb5b1V_Q3874s8,3857
+ tinygrad/runtime/support/hcq.py,sha256=JWLRwoGPZCgRBryNb6uq0qCcfkVfMmNnlnHmCzPFxFI,21901
+ tinygrad/runtime/support/llvm.py,sha256=-qI8NRmSx2dBZnG-OQCTWfmy0XySj_T8kHNRHULTyG4,2174
+ tinygrad/runtime/support/am/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tinygrad/runtime/support/am/amdev.py,sha256=LD80Bc9DHdl99bxrfwn2c43m1NdDD0mjPnSOJkEfkgU,20433
+ tinygrad/runtime/support/am/ip.py,sha256=WnSIZWSG9IvSjVYLJv-VShW_X6vcziynj__lvTOt4yQ,24531
+ tinygrad/shape/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tinygrad/shape/shapetracker.py,sha256=clrYjN-zEZHSh_biDy3jnHL4Fq9HyWCdH_CRpwLKki0,7612
+ tinygrad/shape/view.py,sha256=7KJwP2lS1YWhMq80Ka6_h_qavyxSEm9jWloVgHYRx-k,18110
+ tinygrad-0.10.1.dist-info/LICENSE,sha256=ABRhUPEILzINYIukgazD-_rPipkUNUwslrb0RxnV6Xc,1058
+ tinygrad-0.10.1.dist-info/METADATA,sha256=JMb7PpBcZqsf0tbmNJzvBqUrH6U1JER_0mz20UMo0LA,11241
+ tinygrad-0.10.1.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+ tinygrad-0.10.1.dist-info/top_level.txt,sha256=vDABMCWBFQnx2kn9Azueu88FP-1klQdePoHikQhHymc,9
+ tinygrad-0.10.1.dist-info/RECORD,,
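Each RECORD entry above has the form `path,sha256=<urlsafe base64 digest, unpadded>,<size in bytes>`, per the wheel spec. As a rough sketch (the helper name and usage are illustrative, not part of tinygrad), such an entry could be checked against an unpacked wheel like this:

```python
import base64, hashlib, pathlib

def check_record_entry(entry: str, root: pathlib.Path) -> bool:
  # entry looks like "tinygrad/ops.py,sha256=T1gmR1y...,70614"
  path, hash_spec, size = entry.rsplit(",", 2)
  data = (root / path).read_bytes()
  digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=").decode()
  return hash_spec == f"sha256={digest}" and int(size) == len(data)
```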
{tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/WHEEL
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (75.5.0)
+ Generator: setuptools (75.8.0)
  Root-Is-Purelib: true
  Tag: py3-none-any

tinygrad/engine/lazy.py DELETED
@@ -1,228 +0,0 @@
- from __future__ import annotations
- from typing import Optional, Any, Tuple, List, get_args
- from tinygrad.dtype import dtypes, DType, ConstType, to_dtype, ImageDType
- from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata, SPLIT_REDUCEOP, LAZYCACHE
- from tinygrad.ops import exec_alu, python_alu
- from tinygrad.ops import identity_element, MathTrait, resolve, UOp, sint, GroupOp, Ops
- from tinygrad.shape.shapetracker import ShapeTracker
- from tinygrad.device import Buffer
- from weakref import ref, ReferenceType, WeakValueDictionary
-
- lazycache: WeakValueDictionary[Any, LazyBuffer] = WeakValueDictionary()
- def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Ops]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
-                       base:Optional[LazyBuffer]=None, enable_cache=bool(LAZYCACHE)):
-   if st.size == 0: op, arg, srcs, base = Ops.CONST, 0, (), None
-   dtype = to_dtype(dtype)
-   if op is Ops.CONST: arg, enable_cache = dtypes.as_const(arg, dtype) if not isinstance(arg, UOp) else arg, True
-
-   cache_key = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base))
-   if enable_cache and (rret := lazycache.get(cache_key, None)) is not None: return rret
-
-   ret = LazyBuffer(device, st, dtype, op, arg, srcs, base=base, metadata=_METADATA.get())
-   if enable_cache: lazycache[cache_key] = ret
-   return ret
-
- view_supported_devices = {"LLVM", "CLANG", "CUDA", "NV", "AMD", "METAL", "QCOM", "DSP", "DISK"}
- class LazyBuffer(MathTrait):
-   def __init__(self, device:str, st:ShapeTracker, dtype:DType,
-                op:Optional[Ops]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
-                base:Optional[LazyBuffer]=None, metadata:Optional[Metadata]=None):
-     self.device, self.st, self.dtype, self.shape, self.size, self.metadata = device, st, to_dtype(dtype), st.shape, st.size, metadata
-     self._base: Optional[LazyBuffer] = None
-     if base is None:
-       # properties on base
-       self.op, self.arg, self.srcs = op, arg, srcs # this is a UOp, except the src is LazyBuffers and not UOps
-       assert self.op is not Ops.ASSIGN or srcs[0].base.realized is not None, "assign target must be realized"
-
-       if self.op is Ops.BUFFER_VIEW:
-         # some LazyBuffers can be processed with only a view, no AST required
-         self.buffer: Buffer = srcs[0].base.buffer.view(st.size, self.dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize)
-       else:
-         self.buffer = srcs[0].base.buffer if self.op is Ops.ASSIGN else Buffer(device, self.size, self.dtype)
-       self.buffer.ref(1)
-       self.contiguous_child: Optional[Tuple[ReferenceType[LazyBuffer], ShapeTracker]] = None
-       self.forced_realize = False
-     else:
-       # properties on view
-       assert base.base == base, "base must be a base itself"
-       self._base = base
-
-   def __del__(self):
-     if hasattr(self, 'buffer'): self.buffer.ref(-1)
-
-   def __repr__(self) -> str:
-     return f"<LB {self.device} {self.shape} {str(self.dtype)[7:]} {self.st if self.base is not self else (self.op, self.realized)}>"
-
-   @property
-   def realized(self) -> Optional[Buffer]:
-     # NOTE: we check for a lack of srcs instead of an allocated buffer to make unrealized assigns return None here
-     return self.buffer if self._base is None and not hasattr(self, 'srcs') else None
-
-   # NOTE: this has to be a function to prevent self reference
-   @property
-   def base(self) -> LazyBuffer: return self._base if self._base is not None else self
-
-   # same API as multi
-   @property
-   def lbs(self) -> List[LazyBuffer]: return [self]
-
-   @staticmethod
-   def metaop(op, shape:Tuple[sint,...], dtype:DType, device:str, arg=None, src:Tuple[LazyBuffer, ...]=(), enable_cache=False) -> LazyBuffer:
-     assert isinstance(src, tuple)
-     return create_lazybuffer(device, ShapeTracker.from_shape(shape), dtype, op, arg, src, enable_cache=enable_cache)
-
-   def const_like(self, b): return self.const_with_shape(b, self.shape)
-   def const_with_shape(self, val:ConstType, shape:Tuple[sint,...]) -> LazyBuffer:
-     assert isinstance(val, get_args(ConstType)), f"{val=} has {type(val)=}, not a ConstType"
-     return LazyBuffer.metaop(Ops.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(shape)).expand(shape)
-
-   @property
-   def is_realized(self) -> bool: return self.base.realized is not None
-
-   def assign(self, x:LazyBuffer) -> LazyBuffer:
-     assert x.size == self.size, f"assign target must have same size {self.size=} != {x.size=}"
-     assert self.is_realized, f"assign target must be realized {self}"
-     return LazyBuffer.metaop(Ops.ASSIGN, self.shape, self.dtype, self.device, arg=() if self.st.contiguous else (self.st,),
-                              src=(self.base, x), enable_cache=True)
-
-   def can_view(self):
-     return (self.st.consecutive and not self.is_unrealized_const() and not isinstance(self.dtype, ImageDType) and
-             self.device.split(":")[0] in view_supported_devices)
-
-   def contiguous(self, allow_buffer_view=True):
-     if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const():
-       ret = self.alu(Ops.BUFFER_VIEW) if allow_buffer_view and self.can_view() else self.alu(Ops.CONTIGUOUS)
-       if (sti := self.st.invert(self.base.shape)) is not None: self.base.contiguous_child = ref(ret), sti
-       return ret
-     self.base.forced_realize = True
-     return self
-
-   def bitcast(self, dtype:DType) -> LazyBuffer: return self.cast(dtype, bitcast=True)
-   def cast(self, dtype:DType, bitcast:bool=False, allow_buffer_view=True) -> LazyBuffer:
-     if self.dtype == dtype: return self
-     if self.device.startswith("DISK") and not bitcast: raise RuntimeError("attempted to cast disk buffer (bitcast only)")
-     if self.is_unrealized_unmasked_const() and not bitcast:
-       return create_lazybuffer(self.device, self.st, dtype, Ops.CONST, dtypes.as_const(self.base.arg, dtype))
-     new_shape = self.shape
-     if bitcast and self.dtype.itemsize != dtype.itemsize:
-       if not self.device.startswith("DISK"): raise RuntimeError("shape changing bitcast only supported on DISK right now")
-       if not all_int(new_shape): raise RuntimeError("shape changing bitcast with symbolic shape isn't supported yet")
-       # https://pytorch.org/docs/stable/generated/torch.Tensor.view.html
-       if not (new_shape[-1]*self.dtype.itemsize) % dtype.itemsize == 0: raise RuntimeError("unsupported size in bitcast")
-       new_shape = new_shape[:-1] + ((new_shape[-1]*self.dtype.itemsize) // dtype.itemsize,)
-     elif getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self is not self.base:
-       # TODO: applying this makes gpt2 slower
-       return self.base.cast(dtype, bitcast)._view(self.st)
-     cast_op: Ops = (Ops.BUFFER_VIEW if self.can_view() and allow_buffer_view else Ops.BITCAST) if bitcast else Ops.CAST
-     return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, dtype, (self,))
-
-   def is_unrealized_const(self): return self.base.realized is None and self.base.op is Ops.CONST and not isinstance(self.base.arg, UOp)
-   def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views)
-
-   def _copy(self, device:str) -> LazyBuffer:
-     assert self.st.contiguous and self.size == self.base.size, f"can only copy contig {self} {self.base}"
-     return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, Ops.COPY, self.buffer.nbytes, (self,), enable_cache=False)
-
-   def copy_to_device(self, device:str, force:bool=False, clone:bool=False) -> LazyBuffer:
-     # no COPY
-     if self.device == device and not clone: return self
-
-     # double COPY = one COPY
-     if not force and self.st.contiguous and self.size == self.base.size and not self.base.realized and self.base.op is Ops.COPY:
-       return self.base.srcs[0].copy_to_device(device).reshape(self.st.shape)
-
-     # const doesn't have to be copied (issues with disk tensor)
-     if self.is_unrealized_const():
-       return LazyBuffer.metaop(Ops.CONST, tuple(), self.dtype, device, arg=self.base.arg)._view(self.st)
-
-     # if it's a shrink, do the shrink before the copy with CONTIGUOUS
-     if prod(self.st.shape) < prod(self.base.st.shape): return self.contiguous()._copy(device)
-
-     # copy the base and apply the shapetracker on the new device
-     return self.base._copy(device)._view(self.st)
-
-   def clone(self) -> LazyBuffer: return self.copy_to_device(self.device, clone=True)
-
-   def alu(self, op:Ops, *in_srcs:LazyBuffer) -> LazyBuffer:
-     srcs: List[LazyBuffer] = []
-     for s in (self,)+in_srcs:
-       if s == s.base and s.base.contiguous_child and (root:=s.base.contiguous_child[0]()) is not None:
-         srcs.append(root._view(s.base.contiguous_child[1]))
-       else:
-         srcs.append(s)
-     if not all_same(dts:=[x.dtype.base for x in (srcs[1:] if op is Ops.WHERE else srcs)]):
-       raise AssertionError(f"all dtypes must match {dts} on {op}")
-     assert all_same([x.shape for x in srcs]), f"all shapes must be the same {[x.shape for x in srcs]}"
-     if op is Ops.WHERE: assert srcs[0].dtype == dtypes.bool, "Ops.WHERE must have the first arg be bool"
-
-     out_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else srcs[-1].dtype
-
-     # const folding
-     if op in python_alu and all(s.is_unrealized_unmasked_const() for s in srcs):
-       return self.cast(out_dtype).const_like(exec_alu(op, out_dtype, [s.base.arg for s in srcs]))
-     if op in GroupOp.Binary:
-       x, y = self, in_srcs[0]
-       if op is Ops.ADD:
-         if y.is_unrealized_unmasked_const() and y.base.arg == 0: return x
-         if x.is_unrealized_unmasked_const() and x.base.arg == 0: return y
-       if op is Ops.MUL:
-         if x.is_unrealized_unmasked_const() and (val := x.base.arg) in (1, 0): return y if val == 1 else y.const_like(0)
-         if y.is_unrealized_unmasked_const() and (val := y.base.arg) in (1, 0): return x if val == 1 else x.const_like(0)
-       if op is Ops.IDIV and y.is_unrealized_unmasked_const() and y.base.arg == 1: return x
-
-     return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, None, tuple(srcs))
-
-   # *** reduce ops ***
-
-   def _reduce_op(self, op:Ops, axis:Tuple[int, ...]) -> LazyBuffer:
-     assert all(0 <= x < len(self.shape) for x in axis), f"axis args {axis} out of range for shape {self.shape}"
-     axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
-     if len(axis) == 0: return self
-     return create_lazybuffer(self.device, ShapeTracker.from_shape(self.st.reduce(axis)), self.dtype, Ops.REDUCE_AXIS, (op, axis), (self,))
-
-   def r(self, op:Ops, axis:Tuple[int, ...]) -> LazyBuffer:
-     new_shape = self.st.reduce(axis)
-     # TODO: this logic should move to the scheduler
-     if 0 in self.shape and 0 not in new_shape: return self.const_with_shape(identity_element(op, self.dtype), new_shape)
-
-     # const folding
-     # TODO: fold this for symbolic?
-     if self.is_unrealized_unmasked_const() and all_int(self.shape):
-       if op is Ops.ADD: return self.const_with_shape(self.base.arg * prod(self.shape[i] for i in axis), new_shape)
-       if op is Ops.MUL: return self.const_with_shape(self.base.arg ** prod(self.shape[i] for i in axis), new_shape)
-       if op is Ops.MAX: return self.const_with_shape(self.base.arg, new_shape)
-
-     # TODO: can we split symbolic shape if the reduce axis is not symbolic?
-     if not SPLIT_REDUCEOP or not all_int(self.shape) or (0 in self.shape) or \
-       prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
-       return self._reduce_op(op, axis)
-
-     # if there are few globals, make some reduces into globals by splitting into two kernels
-     # cap output buffer to 2**22: heuristic number of global outputs to achieve max occupancy with enough locals+upcasts for gemm
-     # ~2**10 should be enough if GROUP is used
-     # 256 split maximum should be "negligible reduce" for low prod(new_shape), 8 split minimum.
-     # split is moved to the end to provide maximum locality for the second phase reduce.
-     self_real_strides = self.st.real_strides(ignore_valid=True)
-     split_candidates = [(i, x) for i in axis for x in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(new_shape)),8-1,-1)
-                         if self.shape[i] % x == 0 and self_real_strides[i] != 0]
-     if not split_candidates: return self._reduce_op(op, axis)
-     dim_to_split, divisor = split_candidates[0]
-     splitted_shape = self.shape[:dim_to_split] + (divisor,) + (self.shape[dim_to_split]//divisor,) + self.shape[dim_to_split+1:]
-     splitted = self.reshape(splitted_shape).permute(tuple([x for x in range(len(splitted_shape)) if x != dim_to_split]+[dim_to_split]))
-     if DEBUG >= 3: print(f"split {divisor}: {self.shape} -> {splitted.shape} -> {new_shape}")
-     return splitted._reduce_op(op, axis)._reduce_op(op, (len(new_shape),)).reshape(new_shape) # reduce original axes, then split
-
-   # *** movement ops ***
-
-   def _view(self, new_st:ShapeTracker) -> LazyBuffer:
-     if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)):
-       return self.const_with_shape(0, new_st.shape)
-     if new_st.contiguous and self.base.shape == new_st.shape: return self.base
-     return create_lazybuffer(self.device, new_st, self.dtype, base=self.base)
-
-   def reshape(self, arg:Tuple[sint, ...]): return self._view(self.st.reshape(arg))
-   def pad(self, arg:Tuple[Tuple[sint, sint], ...]): return self._view(self.st.pad(arg))
-   def expand(self, arg:Tuple[sint, ...]): return self._view(self.st.expand(arg))
-   def permute(self, arg:Tuple[int, ...]): return self._view(self.st.permute(arg))
-   def shrink(self, arg:Tuple[Tuple[sint, sint], ...]): return self._view(self.st.shrink(arg))
-   def stride(self, arg:Tuple[int, ...]): return self._view(self.st.stride(arg))
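One detail worth noting in the removed file: `create_lazybuffer` deduplicated lazy buffers through a `WeakValueDictionary`, so identical (device, shapetracker, dtype, op, arg, srcs) requests returned the same live object without the cache itself keeping anything alive. A minimal, self-contained sketch of that pattern (the names here are illustrative, not tinygrad's):

```python
from weakref import WeakValueDictionary

class Node:
  def __init__(self, key): self.key = key

_cache: WeakValueDictionary = WeakValueDictionary()

def create_node(key) -> Node:
  if (hit := _cache.get(key)) is not None: return hit  # cache hit: reuse the live object
  ret = Node(key)
  _cache[key] = ret  # weak reference: the entry vanishes once ret is garbage collected
  return ret

a = create_node(("CPU", (3, 3), "float"))
b = create_node(("CPU", (3, 3), "float"))
assert a is b  # same key while 'a' is still alive -> same object
```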
tinygrad/function.py DELETED
@@ -1,212 +0,0 @@
- """This is where the forwards and backwards passes live."""
- import math
- from typing import Tuple, Optional
- from tinygrad.helpers import argsort
- from tinygrad.dtype import dtypes, DType, sum_acc_dtype
- from tinygrad.ops import Ops, resolve, sint
- from tinygrad.tensor import Function
- from tinygrad.engine.lazy import LazyBuffer
-
- class Contiguous(Function):
-   def forward(self, x:LazyBuffer) -> LazyBuffer: return x.contiguous()
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output
-
- class ContiguousBackward(Function):
-   def forward(self, x:LazyBuffer) -> LazyBuffer: return x
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.contiguous()
-
- class Cast(Function):
-   def forward(self, x:LazyBuffer, dtype:DType, bitcast:bool=False) -> LazyBuffer:
-     self.input_dtype, self.bitcast = x.dtype, bitcast
-     return x.bitcast(dtype) if self.bitcast else x.cast(dtype)
-
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-     if self.bitcast: raise RuntimeError("bitcast cannot backward")
-     return grad_output.cast(self.input_dtype)
-
- # ************* unary ops *************
-
- class Reciprocal(Function):
-   def forward(self, x:LazyBuffer) -> LazyBuffer:
-     self.ret = x.reciprocal()
-     return self.ret
-
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return -grad_output * self.ret * self.ret
-
- class Sin(Function):
-   def forward(self, x:LazyBuffer) -> LazyBuffer:
-     self.x = x
-     return x.sin()
-
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return (math.pi/2 - self.x).sin() * grad_output
-
- class Relu(Function):
-   def forward(self, x:LazyBuffer) -> LazyBuffer:
-     self.ret = x.maximum(0)
-     return self.ret
-
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return self.ret.gt(0).cast(grad_output.dtype) * grad_output
-
- class Log(Function):
-   def forward(self, x:LazyBuffer) -> LazyBuffer:
-     self.x = x
-     return x.log2() * math.log(2)
-
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output / self.x
-
- class Exp(Function):
-   def forward(self, x:LazyBuffer) -> LazyBuffer:
-     self.ret = (x * (1/math.log(2))).exp2()
-     return self.ret
-
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return self.ret * grad_output
-
- class Sqrt(Function):
-   def forward(self, x:LazyBuffer) -> LazyBuffer:
-     self.ret = x.sqrt()
-     return self.ret
-
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output / (self.ret*2)
-
- # NOTE: the implicit derivative of sigmoid is not stable
- # https://towardsdatascience.com/derivative-of-the-sigmoid-function-536880cf918e
- # TODO: have the backend automatically find this
- class Sigmoid(Function):
-   def forward(self, x:LazyBuffer) -> LazyBuffer:
-     self.ret = (1 + (x * (-1/math.log(2))).exp2()).reciprocal()
-     return self.ret
-
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-     return (self.ret * (1 - self.ret)) * grad_output
-
- class Sign(Function):
-   def forward(self, x:LazyBuffer) -> LazyBuffer: return x.ne(0).where(x.lt(0).where(x.const_like(-1), x.const_like(1)), x.const_like(0))
-   # backward always return 0 to match torch
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.const_like(0)
-
- # ************* binary ops *************
-
- class Less(Function):
-   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.lt(y)
-   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None
-
- class Neq(Function):
-   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.ne(y)
-   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None
-
- class Xor(Function):
-   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x^y
-
- class BitwiseAnd(Function):
-   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x&y
-
- class BitwiseOr(Function):
-   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x|y
-
- class Threefry(Function):
-   def forward(self, x:LazyBuffer, seed:LazyBuffer) -> LazyBuffer: return x.threefry(seed)
-
- class Add(Function):
-   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x+y
-
-   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
-     return grad_output if self.needs_input_grad[0] else None, \
-            grad_output if self.needs_input_grad[1] else None
-
- class Mul(Function):
-   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-     self.x, self.y = x, y
-     return x * y
-
-   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
-     return (self.y * grad_output) if self.needs_input_grad[0] else None, \
-            (self.x * grad_output) if self.needs_input_grad[1] else None
-
- class IDiv(Function):
-   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x // y
-
- # ************* ternary ops *************
-
- class Where(Function):
-   def forward(self, x:LazyBuffer, y:LazyBuffer, z:LazyBuffer) -> LazyBuffer:
-     self.x = x
-     return self.x.where(y, z)
-
-   def backward(self, grad_output:LazyBuffer) -> Tuple[None, Optional[LazyBuffer], Optional[LazyBuffer]]:
-     return None, \
-            self.x.where(grad_output, grad_output.const_like(0)) if self.needs_input_grad[1] else None, \
-            self.x.where(grad_output.const_like(0), grad_output) if self.needs_input_grad[2] else None
-
- # ************* reduce ops *************
-
- class Sum(Function):
-   def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
-     self.input_shape = x.shape
-     return x.r(Ops.ADD, axis)
-
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.expand(self.input_shape)
-
- class Prod(Function):
-   def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
-     self.x, self.ret = x, x.r(Ops.MUL, axis)
-     return self.ret
-
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-     return (grad_output * self.ret).expand(self.x.shape) / self.x
-
- class Max(Function):
-   def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
-     self.x, self.ret, self.axis = x, x.r(Ops.MAX, axis), axis
-     return self.ret
-
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-     # 1s in locations where the max was chosen (can be two locations)
-     max_is_1s = self.x.ne(self.ret.expand(self.x.shape)).ne(self.x.const_like(1).cast(dtypes.bool)).cast(grad_output.dtype)
-     div = max_is_1s.r(Ops.ADD, self.axis).expand(self.x.shape)
-     return (max_is_1s/div) * grad_output.expand(self.x.shape)
-
- # ************* movement ops *************
-
- # NOTE: this is sum in reverse
- class Expand(Function):
-   def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
-     self.expanded_axis = tuple(i for i, (si, so) in enumerate(zip(x.shape, shape)) if resolve(si != so))
-     return x.expand(shape)
-
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-     return grad_output.cast(sum_acc_dtype(grad_output.dtype)).r(Ops.ADD, self.expanded_axis).cast(grad_output.dtype)
-
- class Reshape(Function):
-   def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
-     self.input_shape = x.shape
-     return x.reshape(shape)
-
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.reshape(self.input_shape)
-
- class Permute(Function):
-   def forward(self, x:LazyBuffer, order:Tuple[int, ...]) -> LazyBuffer:
-     self.input_order = order
-     return x.permute(order)
-
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.permute(argsort(self.input_order))
-
- class Pad(Function):
-   def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
-     self.narg = tuple([(p[0], s+p[0]) for s,p in zip(x.shape, arg)])
-     return x.pad(arg)
-
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.shrink(self.narg)
-
- class Shrink(Function):
-   def forward(self, x:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer:
-     self.narg = tuple([(p[0], s-p[1]) for s,p in zip(x.shape, arg)])
-     return x.shrink(arg)
-
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.pad(self.narg)
-
- class Flip(Function):
-   def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
-     self.arg = tuple([-1 if i in axis else 1 for i in range(len(x.shape))])
-     return x.stride(self.arg)
-
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.stride(self.arg)
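The deleted `function.py` implemented autograd as `Function` subclasses: `forward` saves whatever the chain rule will need, and `backward` maps the incoming gradient back onto each input; judging by the file list, this role is presumably taken over by the new `tinygrad/gradient.py` added in 0.10.1. A minimal sketch of that pattern with plain floats, purely for illustration (not tinygrad's actual API):

```python
class Function:
  """Base class: subclasses stash in forward() whatever backward() will need."""

class Mul(Function):
  def forward(self, x: float, y: float) -> float:
    self.x, self.y = x, y  # save inputs for the backward pass
    return x * y
  def backward(self, grad_output: float) -> tuple[float, float]:
    return grad_output * self.y, grad_output * self.x  # d(xy)/dx = y, d(xy)/dy = x

f = Mul()
out = f.forward(3.0, 4.0)  # 12.0
dx, dy = f.backward(1.0)   # (4.0, 3.0)
```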