tinygrad 0.8.0.tar.gz → 0.9.1.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. {tinygrad-0.8.0 → tinygrad-0.9.1}/LICENSE +1 -1
  2. {tinygrad-0.8.0 → tinygrad-0.9.1}/PKG-INFO +31 -13
  3. {tinygrad-0.8.0 → tinygrad-0.9.1}/README.md +17 -9
  4. {tinygrad-0.8.0 → tinygrad-0.9.1}/setup.py +20 -5
  5. tinygrad-0.9.1/test/test_arange.py +19 -0
  6. tinygrad-0.9.1/test/test_assign.py +378 -0
  7. tinygrad-0.9.1/test/test_const_folding.py +258 -0
  8. {tinygrad-0.8.0 → tinygrad-0.9.1}/test/test_conv.py +10 -12
  9. {tinygrad-0.8.0 → tinygrad-0.9.1}/test/test_conv_shapetracker.py +6 -7
  10. {tinygrad-0.8.0 → tinygrad-0.9.1}/test/test_copy_speed.py +4 -4
  11. {tinygrad-0.8.0 → tinygrad-0.9.1}/test/test_custom_function.py +10 -8
  12. tinygrad-0.9.1/test/test_device_speed.py +38 -0
  13. tinygrad-0.9.1/test/test_dtype.py +718 -0
  14. tinygrad-0.9.1/test/test_dtype_alu.py +165 -0
  15. {tinygrad-0.8.0 → tinygrad-0.9.1}/test/test_fusion_op.py +5 -6
  16. tinygrad-0.9.1/test/test_fuzz_shape_ops.py +88 -0
  17. {tinygrad-0.8.0 → tinygrad-0.9.1}/test/test_gc.py +8 -8
  18. tinygrad-0.9.1/test/test_graph.py +235 -0
  19. {tinygrad-0.8.0 → tinygrad-0.9.1}/test/test_image_dtype.py +2 -2
  20. {tinygrad-0.8.0 → tinygrad-0.9.1}/test/test_jit.py +168 -72
  21. tinygrad-0.9.1/test/test_lazybuffer.py +117 -0
  22. {tinygrad-0.8.0 → tinygrad-0.9.1}/test/test_lazyop.py +3 -2
  23. tinygrad-0.9.1/test/test_linearizer.py +1778 -0
  24. tinygrad-0.9.1/test/test_linearizer_failures.py +255 -0
  25. tinygrad-0.9.1/test/test_linearizer_overflows.py +89 -0
  26. {tinygrad-0.8.0 → tinygrad-0.9.1}/test/test_method_cache.py +1 -2
  27. tinygrad-0.9.1/test/test_multitensor.py +841 -0
  28. {tinygrad-0.8.0 → tinygrad-0.9.1}/test/test_net_speed.py +1 -1
  29. {tinygrad-0.8.0 → tinygrad-0.9.1}/test/test_nn.py +182 -80
  30. tinygrad-0.9.1/test/test_ops.py +1843 -0
  31. {tinygrad-0.8.0 → tinygrad-0.9.1}/test/test_optim.py +31 -4
  32. tinygrad-0.9.1/test/test_pattern_matcher.py +168 -0
  33. tinygrad-0.9.1/test/test_pickle.py +70 -0
  34. tinygrad-0.9.1/test/test_print_tree.py +66 -0
  35. {tinygrad-0.8.0 → tinygrad-0.9.1}/test/test_randomness.py +75 -14
  36. tinygrad-0.9.1/test/test_schedule.py +1156 -0
  37. tinygrad-0.9.1/test/test_search.py +101 -0
  38. tinygrad-0.9.1/test/test_setitem.py +138 -0
  39. {tinygrad-0.8.0 → tinygrad-0.9.1}/test/test_specific_conv.py +5 -5
  40. {tinygrad-0.8.0 → tinygrad-0.9.1}/test/test_speed_v_torch.py +3 -5
  41. tinygrad-0.9.1/test/test_subbuffer.py +52 -0
  42. {tinygrad-0.8.0 → tinygrad-0.9.1}/test/test_symbolic_jit.py +115 -5
  43. {tinygrad-0.8.0 → tinygrad-0.9.1}/test/test_symbolic_ops.py +58 -3
  44. {tinygrad-0.8.0 → tinygrad-0.9.1}/test/test_symbolic_shapetracker.py +56 -4
  45. {tinygrad-0.8.0 → tinygrad-0.9.1}/test/test_tensor.py +229 -62
  46. {tinygrad-0.8.0 → tinygrad-0.9.1}/test/test_tensor_data.py +9 -0
  47. tinygrad-0.9.1/test/test_tensor_variable.py +74 -0
  48. tinygrad-0.9.1/test/test_uop_graph.py +190 -0
  49. tinygrad-0.9.1/test/test_uops.py +319 -0
  50. tinygrad-0.9.1/test/test_uops_stats.py +81 -0
  51. tinygrad-0.9.1/test/test_verify_lazyop.py +64 -0
  52. tinygrad-0.9.1/test/test_winograd.py +72 -0
  53. {tinygrad-0.8.0 → tinygrad-0.9.1}/test/test_zero_copy.py +1 -1
  54. tinygrad-0.9.1/tinygrad/__init__.py +6 -0
  55. tinygrad-0.9.1/tinygrad/codegen/__init__.py +0 -0
  56. {tinygrad-0.8.0 → tinygrad-0.9.1}/tinygrad/codegen/kernel.py +253 -225
  57. tinygrad-0.9.1/tinygrad/codegen/linearizer.py +528 -0
  58. tinygrad-0.9.1/tinygrad/codegen/uops.py +451 -0
  59. tinygrad-0.9.1/tinygrad/device.py +320 -0
  60. tinygrad-0.9.1/tinygrad/dtype.py +113 -0
  61. tinygrad-0.9.1/tinygrad/engine/__init__.py +0 -0
  62. tinygrad-0.9.1/tinygrad/engine/graph.py +100 -0
  63. tinygrad-0.9.1/tinygrad/engine/jit.py +198 -0
  64. tinygrad-0.9.1/tinygrad/engine/realize.py +192 -0
  65. tinygrad-0.9.1/tinygrad/engine/schedule.py +370 -0
  66. tinygrad-0.9.1/tinygrad/engine/search.py +199 -0
  67. tinygrad-0.8.0/tinygrad/mlops.py → tinygrad-0.9.1/tinygrad/function.py +40 -32
  68. tinygrad-0.9.1/tinygrad/helpers.py +310 -0
  69. tinygrad-0.9.1/tinygrad/lazy.py +220 -0
  70. tinygrad-0.9.1/tinygrad/multi.py +173 -0
  71. {tinygrad-0.8.0 → tinygrad-0.9.1}/tinygrad/nn/__init__.py +180 -9
  72. tinygrad-0.9.1/tinygrad/nn/datasets.py +8 -0
  73. tinygrad-0.9.1/tinygrad/nn/optim.py +150 -0
  74. {tinygrad-0.8.0 → tinygrad-0.9.1}/tinygrad/nn/state.py +87 -19
  75. tinygrad-0.9.1/tinygrad/ops.py +169 -0
  76. tinygrad-0.9.1/tinygrad/renderer/__init__.py +65 -0
  77. tinygrad-0.9.1/tinygrad/renderer/assembly.py +269 -0
  78. tinygrad-0.9.1/tinygrad/renderer/cstyle.py +389 -0
  79. tinygrad-0.9.1/tinygrad/renderer/llvmir.py +160 -0
  80. tinygrad-0.9.1/tinygrad/runtime/__init__.py +0 -0
  81. tinygrad-0.9.1/tinygrad/runtime/autogen/amd_gpu.py +13403 -0
  82. tinygrad-0.9.1/tinygrad/runtime/autogen/comgr.py +891 -0
  83. tinygrad-0.9.1/tinygrad/runtime/autogen/cuda.py +5923 -0
  84. tinygrad-0.9.1/tinygrad/runtime/autogen/hip.py +5909 -0
  85. tinygrad-0.9.1/tinygrad/runtime/autogen/hsa.py +5893 -0
  86. tinygrad-0.9.1/tinygrad/runtime/autogen/io_uring.py +1486 -0
  87. tinygrad-0.9.1/tinygrad/runtime/autogen/kfd.py +812 -0
  88. tinygrad-0.9.1/tinygrad/runtime/autogen/nv_gpu.py +33597 -0
  89. tinygrad-0.9.1/tinygrad/runtime/autogen/opencl.py +1795 -0
  90. tinygrad-0.9.1/tinygrad/runtime/driver/__init__.py +0 -0
  91. tinygrad-0.9.1/tinygrad/runtime/driver/hip_comgr.py +56 -0
  92. tinygrad-0.9.1/tinygrad/runtime/graph/__init__.py +0 -0
  93. tinygrad-0.9.1/tinygrad/runtime/graph/clang.py +39 -0
  94. tinygrad-0.9.1/tinygrad/runtime/graph/cuda.py +81 -0
  95. tinygrad-0.9.1/tinygrad/runtime/graph/hcq.py +187 -0
  96. tinygrad-0.9.1/tinygrad/runtime/graph/metal.py +75 -0
  97. tinygrad-0.9.1/tinygrad/runtime/ops_amd.py +550 -0
  98. tinygrad-0.9.1/tinygrad/runtime/ops_clang.py +28 -0
  99. tinygrad-0.9.1/tinygrad/runtime/ops_cuda.py +185 -0
  100. tinygrad-0.9.1/tinygrad/runtime/ops_disk.py +125 -0
  101. tinygrad-0.9.1/tinygrad/runtime/ops_gpu.py +103 -0
  102. tinygrad-0.9.1/tinygrad/runtime/ops_llvm.py +46 -0
  103. {tinygrad-0.8.0 → tinygrad-0.9.1}/tinygrad/runtime/ops_metal.py +41 -24
  104. tinygrad-0.9.1/tinygrad/runtime/ops_npy.py +9 -0
  105. tinygrad-0.9.1/tinygrad/runtime/ops_nv.py +625 -0
  106. tinygrad-0.9.1/tinygrad/runtime/ops_python.py +208 -0
  107. tinygrad-0.9.1/tinygrad/shape/__init__.py +0 -0
  108. tinygrad-0.9.1/tinygrad/shape/shapetracker.py +121 -0
  109. {tinygrad-0.8.0 → tinygrad-0.9.1}/tinygrad/shape/symbolic.py +99 -98
  110. tinygrad-0.9.1/tinygrad/shape/view.py +311 -0
  111. tinygrad-0.9.1/tinygrad/tensor.py +2962 -0
  112. {tinygrad-0.8.0 → tinygrad-0.9.1}/tinygrad.egg-info/PKG-INFO +31 -13
  113. {tinygrad-0.8.0 → tinygrad-0.9.1}/tinygrad.egg-info/SOURCES.txt +48 -12
  114. {tinygrad-0.8.0 → tinygrad-0.9.1}/tinygrad.egg-info/requires.txt +15 -3
  115. tinygrad-0.8.0/test/test_assign.py +0 -66
  116. tinygrad-0.8.0/test/test_dtype.py +0 -467
  117. tinygrad-0.8.0/test/test_dtype_alu.py +0 -157
  118. tinygrad-0.8.0/test/test_hip_rdna3.py +0 -76
  119. tinygrad-0.8.0/test/test_lazybuffer.py +0 -54
  120. tinygrad-0.8.0/test/test_linearizer.py +0 -609
  121. tinygrad-0.8.0/test/test_linearizer_failures.py +0 -109
  122. tinygrad-0.8.0/test/test_multitensor.py +0 -198
  123. tinygrad-0.8.0/test/test_ops.py +0 -1486
  124. tinygrad-0.8.0/test/test_schedule.py +0 -425
  125. tinygrad-0.8.0/test/test_search.py +0 -20
  126. tinygrad-0.8.0/test/test_uops.py +0 -110
  127. tinygrad-0.8.0/test/test_winograd.py +0 -38
  128. tinygrad-0.8.0/tinygrad/__init__.py +0 -6
  129. tinygrad-0.8.0/tinygrad/codegen/linearizer.py +0 -566
  130. tinygrad-0.8.0/tinygrad/device.py +0 -326
  131. tinygrad-0.8.0/tinygrad/dtype.py +0 -97
  132. tinygrad-0.8.0/tinygrad/features/image.py +0 -93
  133. tinygrad-0.8.0/tinygrad/features/multi.py +0 -103
  134. tinygrad-0.8.0/tinygrad/features/search.py +0 -160
  135. tinygrad-0.8.0/tinygrad/graph.py +0 -106
  136. tinygrad-0.8.0/tinygrad/helpers.py +0 -212
  137. tinygrad-0.8.0/tinygrad/jit.py +0 -152
  138. tinygrad-0.8.0/tinygrad/lazy.py +0 -319
  139. tinygrad-0.8.0/tinygrad/nn/optim.py +0 -72
  140. tinygrad-0.8.0/tinygrad/ops.py +0 -110
  141. tinygrad-0.8.0/tinygrad/realize.py +0 -50
  142. tinygrad-0.8.0/tinygrad/renderer/cstyle.py +0 -291
  143. tinygrad-0.8.0/tinygrad/renderer/llvmir.py +0 -165
  144. tinygrad-0.8.0/tinygrad/runtime/graph/cuda.py +0 -76
  145. tinygrad-0.8.0/tinygrad/runtime/graph/hip.py +0 -24
  146. tinygrad-0.8.0/tinygrad/runtime/graph/metal.py +0 -79
  147. tinygrad-0.8.0/tinygrad/runtime/ops_clang.py +0 -26
  148. tinygrad-0.8.0/tinygrad/runtime/ops_cpu.py +0 -45
  149. tinygrad-0.8.0/tinygrad/runtime/ops_cuda.py +0 -93
  150. tinygrad-0.8.0/tinygrad/runtime/ops_disk.py +0 -57
  151. tinygrad-0.8.0/tinygrad/runtime/ops_gpu.py +0 -101
  152. tinygrad-0.8.0/tinygrad/runtime/ops_hip.py +0 -97
  153. tinygrad-0.8.0/tinygrad/runtime/ops_llvm.py +0 -66
  154. tinygrad-0.8.0/tinygrad/runtime/ops_torch.py +0 -49
  155. tinygrad-0.8.0/tinygrad/shape/shapetracker.py +0 -182
  156. tinygrad-0.8.0/tinygrad/shape/view.py +0 -194
  157. tinygrad-0.8.0/tinygrad/tensor.py +0 -953
  158. {tinygrad-0.8.0 → tinygrad-0.9.1}/setup.cfg +0 -0
  159. {tinygrad-0.8.0 → tinygrad-0.9.1}/test/test_kernel_cache.py +0 -0
  160. {tinygrad-0.8.0 → tinygrad-0.9.1}/test/test_masked_st.py +0 -0
  161. {tinygrad-0.8.0 → tinygrad-0.9.1}/test/test_sample.py +0 -0
  162. {tinygrad-0.8.0 → tinygrad-0.9.1}/test/test_to_numpy.py +0 -0
  163. {tinygrad-0.8.0 → tinygrad-0.9.1}/tinygrad.egg-info/dependency_links.txt +0 -0
  164. {tinygrad-0.8.0 → tinygrad-0.9.1}/tinygrad.egg-info/top_level.txt +0 -0
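Before the per-file diffs, note the module reorganization visible in the listing above: `tinygrad/jit.py`, `tinygrad/realize.py`, and `tinygrad/graph.py` move under `tinygrad/engine/`, `tinygrad/mlops.py` becomes `tinygrad/function.py`, and the `tinygrad/features/` package is dissolved (`multi.py` moves to the package root, `search.py` to `tinygrad/engine/`). The user-facing names remain importable from the package root, which is how the new test files in this diff import them. A minimal sketch of those imports (illustrative, not itself part of the diff):

```python
# top-level imports as used by the new tests below (test/test_assign.py, test/test_arange.py)
from tinygrad import Tensor, TinyJit, GlobalCounters, Variable, dtypes
from tinygrad.helpers import Context  # context manager for runtime flags such as NOOPT
```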
{tinygrad-0.8.0 → tinygrad-0.9.1}/LICENSE
@@ -1,4 +1,4 @@
- Copyright (c) 2023 George Hotz
+ Copyright (c) 2024, the tiny corp

  Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

{tinygrad-0.8.0 → tinygrad-0.9.1}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: tinygrad
- Version: 0.8.0
+ Version: 0.9.1
  Summary: You like pytorch? You like micrograd? You love tinygrad! <3
  Author: George Hotz
  License: MIT
@@ -10,8 +10,6 @@ Requires-Python: >=3.8
  Description-Content-Type: text/markdown
  License-File: LICENSE
  Requires-Dist: numpy
- Requires-Dist: tqdm
- Requires-Dist: gpuctypes
  Requires-Dist: pyobjc-framework-Metal; platform_system == "Darwin"
  Requires-Dist: pyobjc-framework-libdispatch; platform_system == "Darwin"
  Provides-Extra: llvm
@@ -32,10 +30,11 @@ Requires-Dist: torch; extra == "testing"
  Requires-Dist: pillow; extra == "testing"
  Requires-Dist: pytest; extra == "testing"
  Requires-Dist: pytest-xdist; extra == "testing"
- Requires-Dist: onnx==1.15.0; extra == "testing"
+ Requires-Dist: onnx==1.16.0; extra == "testing"
  Requires-Dist: onnx2torch; extra == "testing"
  Requires-Dist: opencv-python; extra == "testing"
  Requires-Dist: tabulate; extra == "testing"
+ Requires-Dist: tqdm; extra == "testing"
  Requires-Dist: safetensors; extra == "testing"
  Requires-Dist: transformers; extra == "testing"
  Requires-Dist: sentencepiece; extra == "testing"
@@ -43,16 +42,30 @@ Requires-Dist: tiktoken; extra == "testing"
  Requires-Dist: librosa; extra == "testing"
  Requires-Dist: networkx; extra == "testing"
  Requires-Dist: hypothesis; extra == "testing"
+ Requires-Dist: nibabel; extra == "testing"
+ Requires-Dist: bottle; extra == "testing"
+ Provides-Extra: docs
+ Requires-Dist: mkdocs-material; extra == "docs"
+ Requires-Dist: mkdocstrings[python]; extra == "docs"
+ Requires-Dist: markdown-callouts; extra == "docs"
+ Requires-Dist: markdown-exec[ansi]; extra == "docs"
+ Requires-Dist: black; extra == "docs"
+ Provides-Extra: testing-tf
+ Requires-Dist: tensorflow==2.15.1; extra == "testing-tf"
+ Requires-Dist: tensorflow_addons; extra == "testing-tf"

  <div align="center">

- [![logo](https://raw.githubusercontent.com/tinygrad/tinygrad/master/docs/logo.png)](https://tinygrad.org)
+ <picture>
+ <source media="(prefers-color-scheme: light)" srcset="/docs/logo_tiny_light.svg">
+ <img alt="tiny corp logo" src="/docs/logo_tiny_dark.svg" width="50%" height="50%">
+ </picture>

  tinygrad: For something between [PyTorch](https://github.com/pytorch/pytorch) and [karpathy/micrograd](https://github.com/karpathy/micrograd). Maintained by [tiny corp](https://tinygrad.org).

  <h3>

- [Homepage](https://github.com/tinygrad/tinygrad) | [Documentation](/docs) | [Examples](/examples) | [Showcase](/docs/showcase.md) | [Discord](https://discord.gg/ZjZadyC7PK)
+ [Homepage](https://github.com/tinygrad/tinygrad) | [Documentation](https://docs.tinygrad.org/) | [Discord](https://discord.gg/ZjZadyC7PK)

  </h3>

@@ -122,17 +135,15 @@ See [examples/beautiful_mnist.py](examples/beautiful_mnist.py) for the full vers

  tinygrad already supports numerous accelerators, including:

- - [x] [CPU](tinygrad/runtime/ops_cpu.py)
  - [x] [GPU (OpenCL)](tinygrad/runtime/ops_gpu.py)
- - [x] [C Code (Clang)](tinygrad/runtime/ops_clang.py)
+ - [x] [CLANG (C Code)](tinygrad/runtime/ops_clang.py)
  - [x] [LLVM](tinygrad/runtime/ops_llvm.py)
  - [x] [METAL](tinygrad/runtime/ops_metal.py)
  - [x] [CUDA](tinygrad/runtime/ops_cuda.py)
- - [x] [PyTorch](tinygrad/runtime/ops_torch.py)
- - [x] [HIP](tinygrad/runtime/ops_hip.py)
+ - [x] [AMD](tinygrad/runtime/ops_amd.py)
+ - [x] [NV](tinygrad/runtime/ops_nv.py)

  And it is easy to add more! Your accelerator of choice only needs to support a total of ~25 low level ops.
- More information can be found in the [documentation for adding new accelerators](/docs/adding_new_accelerators.md).

  ## Installation

@@ -154,7 +165,7 @@ python3 -m pip install git+https://github.com/tinygrad/tinygrad.git

  ## Documentation

- Documentation along with a quick start guide can be found in the [docs/](/docs) directory.
+ Documentation along with a quick start guide can be found on the [docs website](https://docs.tinygrad.org/) built from the [docs/](/docs) directory.

  ### Quick example comparing to PyTorch

@@ -193,13 +204,14 @@ We'll start with what will get your PR closed with a pointer to this section:
  - All docs and whitespace changes will be closed unless you are a well-known contributor. The people writing the docs should be those who know the codebase the absolute best. People who have not demonstrated that shouldn't be messing with docs. Whitespace changes are both useless *and* carry a risk of introducing bugs.
  - Anything you claim is a "speedup" must be benchmarked. In general, the goal is simplicity, so even if your PR makes things marginally faster, you have to consider the tradeoff with maintainablity and readablity.
  - In general, the code outside the core `tinygrad/` folder is not well tested, so unless the current code there is broken, you shouldn't be changing it.
+ - If your PR looks "complex", is a big diff, or adds lots of lines, it won't be reviewed or merged. Consider breaking it up into smaller PRs that are individually clear wins. A common pattern I see is prerequisite refactors before adding new functionality. If you can (cleanly) refactor to the point that the feature is a 3 line change, this is great, and something easy for us to review.

  Now, what we want:

  - Bug fixes (with a regression test) are great! This library isn't 1.0 yet, so if you stumble upon a bug, fix it, write a test, and submit a PR, this is valuable work.
  - Solving bounties! tinygrad [offers cash bounties](https://docs.google.com/spreadsheets/d/1WKHbT-7KOgjEawq5h5Ic1qUWzpfAzuD_J06N1JwOCGs/edit?usp=sharing) for certain improvements to the library. All new code should be high quality and well tested.
  - Features. However, if you are adding a feature, consider the line tradeoff. If it's 3 lines, there's less of a bar of usefulness it has to meet over something that's 30 or 300 lines. All features must have regression tests. In general with no other constraints, your feature's API should match torch or numpy.
- - Refactors that are clear wins. In general, if your refactor isn't a clear win it will be closed. But some refactors are amazing! Think about readability in a deep core sense. A whitespace change or moving a few functions around is useless, but if you realize that two 100 line functions can actually use the same 110 line function with arguments while also improving readability, this is a big win.
+ - Refactors that are clear wins. In general, if your refactor isn't a clear win it will be closed. But some refactors are amazing! Think about readability in a deep core sense. A whitespace change or moving a few functions around is useless, but if you realize that two 100 line functions can actually use the same 110 line function with arguments while also improving readability, this is a big win. Refactors should pass [process replay](#process-replay-tests).
  - Tests/fuzzers. If you can add tests that are non brittle, they are welcome. We have some fuzzers in here too, and there's a plethora of bugs that can be found with them and by improving them. Finding bugs, even writing broken tests (that should pass) with `@unittest.expectedFailure` is great. This is how we make progress.
  - Dead code removal from core `tinygrad/` folder. We don't care about the code in extra, but removing dead code from the core library is great. Less for new people to read and be confused by.

@@ -215,3 +227,9 @@ python3 -m pip install -e '.[testing]' # install extra deps for testing
  python3 test/test_ops.py # just the ops tests
  python3 -m pytest test/ # whole test suite
  ```
+
+ #### Process replay tests
+
+ [Process replay](https://github.com/tinygrad/tinygrad/blob/master/test/external/process_replay/process_replay.py) detects changes in the generated kernels of CI tests by comparing them against tinygrad master. If your PR is a refactor or speedup without any expected behavior change, it should include a green process replay pass to get merged.
+
+ You can enable process replay by adding [run_process_replay] to your PR title. [example](https://github.com/tinygrad/tinygrad/pull/4995). Note that you should keep your branch up-to-date with master.
{tinygrad-0.8.0 → tinygrad-0.9.1}/README.md
@@ -1,12 +1,15 @@
  <div align="center">

- [![logo](https://raw.githubusercontent.com/tinygrad/tinygrad/master/docs/logo.png)](https://tinygrad.org)
+ <picture>
+ <source media="(prefers-color-scheme: light)" srcset="/docs/logo_tiny_light.svg">
+ <img alt="tiny corp logo" src="/docs/logo_tiny_dark.svg" width="50%" height="50%">
+ </picture>

  tinygrad: For something between [PyTorch](https://github.com/pytorch/pytorch) and [karpathy/micrograd](https://github.com/karpathy/micrograd). Maintained by [tiny corp](https://tinygrad.org).

  <h3>

- [Homepage](https://github.com/tinygrad/tinygrad) | [Documentation](/docs) | [Examples](/examples) | [Showcase](/docs/showcase.md) | [Discord](https://discord.gg/ZjZadyC7PK)
+ [Homepage](https://github.com/tinygrad/tinygrad) | [Documentation](https://docs.tinygrad.org/) | [Discord](https://discord.gg/ZjZadyC7PK)

  </h3>

@@ -76,17 +79,15 @@ See [examples/beautiful_mnist.py](examples/beautiful_mnist.py) for the full vers

  tinygrad already supports numerous accelerators, including:

- - [x] [CPU](tinygrad/runtime/ops_cpu.py)
  - [x] [GPU (OpenCL)](tinygrad/runtime/ops_gpu.py)
- - [x] [C Code (Clang)](tinygrad/runtime/ops_clang.py)
+ - [x] [CLANG (C Code)](tinygrad/runtime/ops_clang.py)
  - [x] [LLVM](tinygrad/runtime/ops_llvm.py)
  - [x] [METAL](tinygrad/runtime/ops_metal.py)
  - [x] [CUDA](tinygrad/runtime/ops_cuda.py)
- - [x] [PyTorch](tinygrad/runtime/ops_torch.py)
- - [x] [HIP](tinygrad/runtime/ops_hip.py)
+ - [x] [AMD](tinygrad/runtime/ops_amd.py)
+ - [x] [NV](tinygrad/runtime/ops_nv.py)

  And it is easy to add more! Your accelerator of choice only needs to support a total of ~25 low level ops.
- More information can be found in the [documentation for adding new accelerators](/docs/adding_new_accelerators.md).

  ## Installation

@@ -108,7 +109,7 @@ python3 -m pip install git+https://github.com/tinygrad/tinygrad.git

  ## Documentation

- Documentation along with a quick start guide can be found in the [docs/](/docs) directory.
+ Documentation along with a quick start guide can be found on the [docs website](https://docs.tinygrad.org/) built from the [docs/](/docs) directory.

  ### Quick example comparing to PyTorch

@@ -147,13 +148,14 @@ We'll start with what will get your PR closed with a pointer to this section:
  - All docs and whitespace changes will be closed unless you are a well-known contributor. The people writing the docs should be those who know the codebase the absolute best. People who have not demonstrated that shouldn't be messing with docs. Whitespace changes are both useless *and* carry a risk of introducing bugs.
  - Anything you claim is a "speedup" must be benchmarked. In general, the goal is simplicity, so even if your PR makes things marginally faster, you have to consider the tradeoff with maintainablity and readablity.
  - In general, the code outside the core `tinygrad/` folder is not well tested, so unless the current code there is broken, you shouldn't be changing it.
+ - If your PR looks "complex", is a big diff, or adds lots of lines, it won't be reviewed or merged. Consider breaking it up into smaller PRs that are individually clear wins. A common pattern I see is prerequisite refactors before adding new functionality. If you can (cleanly) refactor to the point that the feature is a 3 line change, this is great, and something easy for us to review.

  Now, what we want:

  - Bug fixes (with a regression test) are great! This library isn't 1.0 yet, so if you stumble upon a bug, fix it, write a test, and submit a PR, this is valuable work.
  - Solving bounties! tinygrad [offers cash bounties](https://docs.google.com/spreadsheets/d/1WKHbT-7KOgjEawq5h5Ic1qUWzpfAzuD_J06N1JwOCGs/edit?usp=sharing) for certain improvements to the library. All new code should be high quality and well tested.
  - Features. However, if you are adding a feature, consider the line tradeoff. If it's 3 lines, there's less of a bar of usefulness it has to meet over something that's 30 or 300 lines. All features must have regression tests. In general with no other constraints, your feature's API should match torch or numpy.
- - Refactors that are clear wins. In general, if your refactor isn't a clear win it will be closed. But some refactors are amazing! Think about readability in a deep core sense. A whitespace change or moving a few functions around is useless, but if you realize that two 100 line functions can actually use the same 110 line function with arguments while also improving readability, this is a big win.
+ - Refactors that are clear wins. In general, if your refactor isn't a clear win it will be closed. But some refactors are amazing! Think about readability in a deep core sense. A whitespace change or moving a few functions around is useless, but if you realize that two 100 line functions can actually use the same 110 line function with arguments while also improving readability, this is a big win. Refactors should pass [process replay](#process-replay-tests).
  - Tests/fuzzers. If you can add tests that are non brittle, they are welcome. We have some fuzzers in here too, and there's a plethora of bugs that can be found with them and by improving them. Finding bugs, even writing broken tests (that should pass) with `@unittest.expectedFailure` is great. This is how we make progress.
  - Dead code removal from core `tinygrad/` folder. We don't care about the code in extra, but removing dead code from the core library is great. Less for new people to read and be confused by.

@@ -169,3 +171,9 @@ python3 -m pip install -e '.[testing]' # install extra deps for testing
  python3 test/test_ops.py # just the ops tests
  python3 -m pytest test/ # whole test suite
  ```
+
+ #### Process replay tests
+
+ [Process replay](https://github.com/tinygrad/tinygrad/blob/master/test/external/process_replay/process_replay.py) detects changes in the generated kernels of CI tests by comparing them against tinygrad master. If your PR is a refactor or speedup without any expected behavior change, it should include a green process replay pass to get merged.
+
+ You can enable process replay by adding [run_process_replay] to your PR title. [example](https://github.com/tinygrad/tinygrad/pull/4995). Note that you should keep your branch up-to-date with master.
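The "Quick example comparing to PyTorch" section referenced in the README hunks above is unchanged in this release, so it only appears as hunk context. For orientation, a minimal sketch of what that side-by-side typically looks like (a hedged illustration using the public APIs, not the README's verbatim snippet):

```python
from tinygrad import Tensor

x = Tensor.eye(3, requires_grad=True)
y = Tensor([[2.0, 0, -2.0]], requires_grad=True)
z = y.matmul(x).sum()
z.backward()
print(x.grad.numpy())  # dz/dx
print(y.grad.numpy())  # dz/dy

# the same computation in PyTorch for comparison
import torch
tx = torch.eye(3, requires_grad=True)
ty = torch.tensor([[2.0, 0, -2.0]], requires_grad=True)
tz = ty.matmul(tx).sum()
tz.backward()
print(tx.grad.numpy(), ty.grad.numpy())
```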
{tinygrad-0.8.0 → tinygrad-0.9.1}/setup.py
@@ -8,19 +8,19 @@ with open(directory / 'README.md', encoding='utf-8') as f:
  long_description = f.read()

  setup(name='tinygrad',
- version='0.8.0',
+ version='0.9.1',
  description='You like pytorch? You like micrograd? You love tinygrad! <3',
  author='George Hotz',
  license='MIT',
  long_description=long_description,
  long_description_content_type='text/markdown',
- packages = ['tinygrad', 'tinygrad.codegen', 'tinygrad.nn', 'tinygrad.renderer',
- 'tinygrad.runtime', 'tinygrad.runtime.graph', 'tinygrad.shape', 'tinygrad.features'],
+ packages = ['tinygrad', 'tinygrad.runtime.autogen', 'tinygrad.codegen', 'tinygrad.nn', 'tinygrad.renderer', 'tinygrad.engine',
+ 'tinygrad.runtime', 'tinygrad.runtime.driver', 'tinygrad.runtime.graph', 'tinygrad.shape'],
  classifiers=[
  "Programming Language :: Python :: 3",
  "License :: OSI Approved :: MIT License"
  ],
- install_requires=["numpy", "tqdm", "gpuctypes",
+ install_requires=["numpy",
  "pyobjc-framework-Metal; platform_system=='Darwin'",
  "pyobjc-framework-libdispatch; platform_system=='Darwin'"],
  python_requires='>=3.8',
@@ -36,15 +36,17 @@ setup(name='tinygrad',
  "ruff",
  "types-tqdm",
  ],
+ #'mlperf': ["mlperf-logging @ git+https://github.com/mlperf/logging.git@4.0.0-rc2"],
  'testing': [
  "torch",
  "pillow",
  "pytest",
  "pytest-xdist",
- "onnx==1.15.0",
+ "onnx==1.16.0",
  "onnx2torch",
  "opencv-python",
  "tabulate",
+ "tqdm",
  "safetensors",
  "transformers",
  "sentencepiece",
@@ -52,6 +54,19 @@ setup(name='tinygrad',
  "librosa",
  "networkx",
  "hypothesis",
+ "nibabel",
+ "bottle",
+ ],
+ 'docs': [
+ "mkdocs-material",
+ "mkdocstrings[python]",
+ "markdown-callouts",
+ "markdown-exec[ansi]",
+ "black"
+ ],
+ 'testing_tf': [
+ "tensorflow==2.15.1",
+ "tensorflow_addons",
  ]
  },
  include_package_data=True)
tinygrad-0.9.1/test/test_arange.py (new file)
@@ -0,0 +1,19 @@
+ import unittest
+ from tinygrad import Tensor, GlobalCounters
+ from tinygrad.helpers import Context
+
+ class TestArange(unittest.TestCase):
+   def _get_flops(self, N):
+     GlobalCounters.reset()
+     with Context(NOOPT=1):
+       Tensor.arange(N).realize()
+     return GlobalCounters.global_ops
+
+   def test_complexity(self):
+     f1 = self._get_flops(256)
+     f2 = self._get_flops(2560)
+     print(f"{f1=}, {f2=}")
+     assert f2 / f1 < 15, f"bad complexity, flops {f2/f1:.1f}X while inputs 10X"
+
+ if __name__ == "__main__":
+   unittest.main()
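The new test above pins down the asymptotic cost of `Tensor.arange`: when the range grows 10x (256 → 2560), the global op count must grow by less than 15x, i.e. the lowering has to stay roughly linear rather than quadratic. A standalone sketch of the same measurement pattern, reusing only the API already shown in the diff:

```python
from tinygrad import Tensor, GlobalCounters
from tinygrad.helpers import Context

def ops_for_arange(n: int) -> int:
  GlobalCounters.reset()            # zero the global op counters
  with Context(NOOPT=1):            # disable the optional kernel optimizations
    Tensor.arange(n).realize()      # force the kernel to actually run
  return GlobalCounters.global_ops

print(ops_for_arange(256), ops_for_arange(2560))  # the ratio should stay well under 15x
```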
tinygrad-0.9.1/test/test_assign.py (new file)
@@ -0,0 +1,378 @@
+ #!/usr/bin/env python
+ import unittest
+ import numpy as np
+ from tinygrad import dtypes, Tensor, TinyJit, GlobalCounters, Variable
+
+ N = 200 # has to be bigger than the cache to fail
+
+ class TestAssign(unittest.TestCase):
+   def test_simple_assignment(self):
+     a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
+     b = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
+     a.realize()
+     b.realize()
+     ba1 = a.lazydata.base.realized
+     bb1 = b.lazydata.base.realized
+     a += b
+     a.realize()
+     ba2 = a.lazydata.base.realized
+     assert ba1 == ba2 and ba1 != bb1
+     np.testing.assert_allclose(a.numpy(), (np.arange(N*N)*2).reshape((N,N)))
+
+   def test_assign_zeros_good(self):
+     a = Tensor.zeros(10,10).contiguous()
+     a.assign(Tensor.ones(10,10))
+     b = Tensor.zeros(10,10).contiguous()
+     a.realize()
+     np.testing.assert_allclose(b.numpy(), 0)
+
+   def test_assign_zeros(self):
+     a = Tensor.zeros(10,10).contiguous()
+     b = Tensor.zeros(10,10).contiguous()
+     a.assign(Tensor.ones(10,10))
+     a.realize()
+     np.testing.assert_allclose(b.numpy(), 0)
+
+   def test_assign_add(self):
+     def f(x):
+       x += 1
+       x.realize()
+     x = Tensor([0])
+     f(x)
+     assert x.item() == 1
+
+   def test_assign_add_twice(self):
+     # NOTE: this has two kernels
+     def f(x):
+       x += 1
+       x += 1
+       x.realize()
+     x = Tensor([0])
+     f(x)
+     assert x.item() == 2
+
+   def test_assign_add_double(self):
+     def f(x):
+       x += 1
+       x.realize()
+     x = Tensor([0])
+     f(x)
+     assert (out:=x.item()) == 1, f"expected 1, got {out}"
+     x = Tensor([0])
+     f(x)
+     assert (out:=x.item()) == 1, f"expected 1, got {out}"
+
+   def test_assign_add_jit(self):
+     @TinyJit
+     def f(x):
+       x += 1
+       x.realize()
+     x = Tensor([0])
+     for _ in range(5): f(x)
+     assert x.item() == 5
+
+   def test_assign_add_jit_other(self):
+     @TinyJit
+     def f(x):
+       x += 1
+       x.realize()
+     x = Tensor([0])
+     for _ in range(5): f(x)
+     assert x.item() == 5
+
+     y = Tensor([0])
+     for _ in range(4): f(y)
+     assert y.item() == 4
+
+   def test_assign_other_jit(self):
+     @TinyJit
+     def f(x, a):
+       x.assign(a)
+       x.realize()
+     x = Tensor([0])
+     for i in range(1, 6):
+       f(x, x.full_like(i).contiguous()) # const would be implicitly folded without contiguous
+       assert x.item() == i
+
+   def test_assign_add_other_jit(self):
+     @TinyJit
+     def f(x, a):
+       x += a
+       x.realize()
+     x = Tensor([0])
+     a = 0
+     for i in range(1, 6):
+       a += i
+       f(x, x.full_like(i).contiguous())
+       assert x.item() == a
+
+   def test_assign_changes(self):
+     a = Tensor.ones(4).contiguous().realize()
+     old_a = a
+     a.assign(Tensor.full((4,), 2.).contiguous())
+     # NOTE: old_a is now 2, and this would match the behavior of pytorch
+     new = a + old_a
+     np.testing.assert_allclose(new.numpy(), 4)
+
+   def test_assign_diamond_cycle(self):
+     # NOTE: should *not* raise AssertionError from numpy
+     with self.assertRaisesRegex(RuntimeError, "cycle"):
+       a = Tensor.ones(4).contiguous().realize()
+       times_a = a*3
+       a.assign(Tensor.full((4,), 2.).contiguous())
+       new = a + (times_a-1)
+       np.testing.assert_allclose(new.numpy(), 4)
+
+   def test_assign_diamond_contiguous_cycle(self):
+     with self.assertRaisesRegex(RuntimeError, "cycle"):
+       a = Tensor.ones(4).contiguous().realize()
+       times_a = a*3
+       a.assign(Tensor.full((4,), 2.))
+       new = a.contiguous() + times_a-1
+       np.testing.assert_allclose(new.numpy(), 4)
+
+   def test_assign_diamond_possible(self):
+     a = Tensor.ones(4).contiguous().realize()
+     times_a = a*3
+     a.assign(Tensor.full((4,), 2.))
+     new = a + (times_a-1).contiguous()
+     np.testing.assert_allclose(new.numpy(), 4)
+
+   def test_assign_diamond_possible_contiguous(self):
+     a = Tensor.ones(4).contiguous().realize()
+     times_a = a*3
+     a.assign(Tensor.full((4,), 2.).contiguous())
+     new = a + (times_a-1).contiguous()
+     np.testing.assert_allclose(new.numpy(), 4)
+
+   def test_assign_diamond_both_contiguous(self):
+     a = Tensor.ones(4).contiguous().realize()
+     times_a = a*3
+     a.assign(Tensor.full((4,), 2.))
+     new = a.contiguous() + (times_a-1).contiguous()
+     np.testing.assert_allclose(new.numpy(), 4)
+
+   def test_assign_diamond_alt(self):
+     a = Tensor.ones(4).contiguous().realize()
+     a.assign(Tensor.full((4,), 2.).contiguous())
+     times_a = a*3
+     new = a + times_a
+     np.testing.assert_allclose(new.numpy(), 8)
+
+   def test_double_assign(self):
+     a = Tensor.ones(4).contiguous().realize()
+     a += 1
+     a += 1
+     np.testing.assert_allclose(a.numpy(), 3)
+
+   def test_crossover_assign(self):
+     a = Tensor.full((4,), 2).contiguous().realize()
+     b = Tensor.full((4,), 3).contiguous().realize()
+     a += b
+     b += a
+     Tensor.realize(a,b)
+     np.testing.assert_allclose(a.numpy(), 5)
+     np.testing.assert_allclose(b.numpy(), 8)
+
+   def test_assign_double_diamond(self):
+     a = Tensor.full((4,), 2).contiguous().realize()
+     b = Tensor.full((4,), 3).contiguous().realize()
+     a_prev = a*4
+     b_prev = b+3
+     b += a_prev.contiguous()
+     a += b_prev.contiguous()
+     Tensor.realize(a, b)
+     np.testing.assert_equal(b.numpy(), 11)
+     np.testing.assert_equal(a.numpy(), 8)
+
+   def test_assign_double_diamond_reduce(self):
+     a0 = Tensor.full((16, 16), 10).contiguous().realize()
+     a1 = Tensor.full((16, 16), 20).contiguous().realize()
+     b0 = Tensor.full((16, ), 1).contiguous().realize()
+     b1 = Tensor.full((16, ), 2).contiguous().realize()
+
+     r0 = (a0 - b1.contiguous()).sum(1)
+     r1 = (a1 - b0.contiguous()).sum(1)
+     b0.assign(r0 * b0)
+     b1.assign(r1 * b1)
+     Tensor.realize(b0, b1)
+     np.testing.assert_equal(b0.numpy(), 128)
+     np.testing.assert_equal(b1.numpy(), 608)
+
+   def test_crossunder_assign(self):
+     # NOTE: should *not* raise AssertionError from numpy
+     with self.assertRaisesRegex(RuntimeError, "cycle"):
+       a = Tensor.full((4,), 2).contiguous().realize()
+       b = Tensor.full((4,), 3).contiguous().realize()
+       c = a+9
+       a += b
+       b += c
+       Tensor.realize(a,b)
+       np.testing.assert_allclose(a.numpy(), 2+3)
+       np.testing.assert_allclose(b.numpy(), 3+2+9)
+
+   def test_assign_kv_cache(self):
+     bsz, max_context = 2, 8
+
+     class Attn:
+       @TinyJit
+       def __call__(self, xk:Tensor, start_pos:Variable):
+         seqlen = xk.shape[1]
+         if not hasattr(self, "cache_k"):
+           self.cache_k = Tensor.zeros(bsz, max_context, 1, 1).contiguous()
+         keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1).contiguous() if start_pos > 0 else xk
+         self.cache_k.assign(keys.pad((None,(0,max_context-start_pos-seqlen),None,None)).contiguous()).realize()
+
+     attn = Attn()
+     xk = Tensor.ones(bsz, 3, 1, 1).contiguous()
+     attn(xk, 0)
+     for i in range(3,6):
+       # copied from LLaMA
+       start_pos = Variable("start_pos", 1, max_context).bind(i)
+       xk = Tensor.ones(bsz, 1, 1, 1).contiguous()
+       attn(xk, start_pos)
+
+     out = attn.cache_k.flatten().numpy()
+     np.testing.assert_allclose(out, [1.,1.,1.,1.,1.,1.,0.,0.,1.,1.,1.,1.,1.,1.,0.,0.])
+
+   def test_assign_contiguous(self):
+     b = Tensor.rand(4,4).realize()
+     a = (Tensor.rand(4,4).realize() + 1)
+     kc = GlobalCounters.kernel_count
+     b.assign(a.contiguous()).realize()
+     assert GlobalCounters.kernel_count - kc == 2
+
+   def test_assign_contiguous_permute(self):
+     b = Tensor.rand(4,4).realize()
+     a = (Tensor.rand(4,4).realize() + 1).permute((1,0))
+     kc = GlobalCounters.kernel_count
+     b.assign(a.contiguous()).realize()
+     assert GlobalCounters.kernel_count - kc == 2
+
+   def test_permuted_assignment(self):
+     a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
+     b = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
+     a.realize()
+     b.realize()
+     ba1 = a.lazydata.base.realized
+     bb1 = b.lazydata.base.realized
+     with self.assertRaises((RuntimeError, AssertionError)):
+       a = a.permute(1,0)
+       a += b
+       a.realize()
+       ba2 = a.lazydata.base.realized
+       assert ba1 != ba2 and ba1 != bb1
+       np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))
+
+   def test_post_permuted_assignment(self):
+     a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
+     b = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
+     a.realize()
+     b.realize()
+     #GlobalCounters.cache = []
+     ba1 = a.lazydata.base.realized # noqa: F841
+     bb1 = b.lazydata.base.realized # noqa: F841
+     with self.assertRaisesRegex(RuntimeError, "contiguous"):
+       a.assign(a.permute(1,0) + b) # this should not work!
+       a.realize()
+       ba2 = a.lazydata.base.realized # noqa: F841
+       # NOTE: don't test that it's assigned
+       #assert ba1 == ba2 and ba1 != bb1
+       np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))
+
+   def test_simple_assignment_multioutput(self):
+     a = Tensor.randn(32, 32).realize()
+     b = Tensor.full((32, ), 1.).contiguous().realize()
+     c = Tensor.full((32, ), 2.).contiguous().realize()
+     d = Tensor.full((32, ), 3.).contiguous().realize()
+
+     r = a.sum(axis=1)
+     b.assign(r + b)
+     c.assign(r + c)
+     d.assign(r + d)
+
+     kc = GlobalCounters.kernel_count
+     Tensor.realize(b, c, d)
+     assert GlobalCounters.kernel_count - kc == 1
+     np.testing.assert_allclose(b.numpy(), a.sum(1).numpy()+1)
+     np.testing.assert_allclose(c.numpy(), a.sum(1).numpy()+2)
+     np.testing.assert_allclose(d.numpy(), a.sum(1).numpy()+3)
+
+   # NOTE: if the assign target is read/write in a single kernel, it should be contiguous
+
+   def test_permuted_assignment_correct(self):
+     a = Tensor.arange(4 * 4).reshape(4, 4).contiguous().realize()
+     b = Tensor.arange(4 * 4).reshape(4, 4).contiguous().realize()
+     # TODO: scheduler limitation, should NOT raise AssertionError from numpy.
+     with self.assertRaisesRegex(RuntimeError, "contiguous"):
+       a = a.permute(1, 0)
+       new_val = a + b
+       a.assign(new_val)
+       np.testing.assert_equal(a.numpy(), np.arange(4 * 4).reshape(4, 4).transpose(1, 0) + np.arange(4 * 4).reshape(4, 4))
+
+   def test_permuted_reduceop_child_dual_use(self):
+     a = Tensor.randn(32, 32, 32).realize()
+     b = Tensor.full((32, 32), 1.).contiguous().realize()
+     with self.assertRaisesRegex(RuntimeError, "contiguous"):
+       r = a.sum(axis=1)
+       b.assign(r + b.permute(1, 0))
+       b.realize()
+
+   def test_permuted_reduceop_multioutput_dual_use(self):
+     a = Tensor.randn(32, 32, 32).realize()
+     b = Tensor.full((32, 32), 1.).contiguous().realize()
+     c = Tensor.full((32, 32), 2.).contiguous().realize()
+
+     with self.assertRaisesRegex(RuntimeError, "contiguous"):
+       r = a.sum(axis=1)
+       b_perm = b.permute(1, 0)
+       b.assign(r + b)
+       c.assign(r + b_perm)
+       Tensor.realize(b, c)
+
+   def test_permuted_reduceop_multioutput_dual_use_possible(self):
+     a = Tensor.randn(32, 32, 32, dtype=dtypes.int).realize()
+     b = Tensor.arange(32 * 32).reshape(32, 32).realize()
+     c = Tensor.arange(32 * 32).reshape(32, 32).realize()
+
+     kc = GlobalCounters.kernel_count
+     r = a.sum(axis=1)
+     b_perm = b.permute(1, 0)
+     b.assign(r + b)
+     c.assign(r + b_perm.contiguous())
+     Tensor.realize(b, c)
+     assert GlobalCounters.kernel_count - kc == 2
+     np.testing.assert_equal(b.numpy(), a.numpy().sum(1) + np.arange(32 * 32).reshape(32, 32))
+     np.testing.assert_equal(c.numpy(), a.numpy().sum(1) + np.arange(32 * 32).reshape(32, 32).transpose(1, 0))
+
+   def test_permuted_assignment_masked_view_possible(self):
+     a = Tensor.ones(4, 4).contiguous().realize()
+     b = a.shrink((None, (0, 2))).pad((None, (0, 2)), 2)
+     a.assign(a + b)
+     kc = GlobalCounters.kernel_count
+     a.realize()
+     assert GlobalCounters.kernel_count - kc == 1
+     np.testing.assert_equal(a.numpy(), np.ones((4, 4))+np.pad(np.ones((4, 4))[:, 0:2], ((0, 0), (0, 2)), constant_values=2))
+
+   def test_permuted_assignment_masked_view_not_contiguous(self):
+     a = Tensor.ones(4, 4).contiguous().realize()
+     with self.assertRaisesRegex(RuntimeError, "contiguous"):
+       b = a.shrink((None, (0, 2))).pad((None, (0, 2)), 2).permute(1, 0)
+       a.assign(a + b)
+       a.realize()
+
+   # TODO: is there a way to sneak in a permute such that it returns the wrong answer?
+
+   @unittest.skip("don't use output buffer, and mismatch dtype no longer supported")
+   def test_cast_assignment(self):
+     a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
+     a.realize()
+     oba1 = a.lazydata.base.output_buffer
+     a.assign(a.cast(dtypes.int32).realize())
+     a.realize()
+     oba2 = a.lazydata.base.output_buffer
+     assert oba1 is None and oba2 is None
+     np.testing.assert_allclose(a.numpy(), np.arange(N*N,dtype=np.int32).reshape((N,N)))
+
+ if __name__ == "__main__":
+   unittest.main()
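Taken together, these tests pin down the 0.9.x `assign` behavior: the tests use realized, contiguous targets, a target that is both read and written by a single kernel must be contiguous (permuted or masked non-contiguous targets raise), and the update is written back into the target's existing buffer rather than allocating a new one. A minimal sketch of the basic pattern the tests exercise, using only the API shown above:

```python
import numpy as np
from tinygrad import Tensor

w = Tensor.ones(4).contiguous().realize()  # realized, contiguous assign target
w.assign(w + 1)                            # schedule an in-place update into w's buffer
w.realize()
np.testing.assert_allclose(w.numpy(), 2)   # same storage, updated values
```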