tinygrad 0.10.0.tar.gz → 0.10.1.tar.gz

This diff compares two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
Files changed (167)
  1. {tinygrad-0.10.0 → tinygrad-0.10.1}/PKG-INFO +20 -9
  2. {tinygrad-0.10.0 → tinygrad-0.10.1}/README.md +5 -4
  3. {tinygrad-0.10.0 → tinygrad-0.10.1}/setup.py +7 -5
  4. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_arange.py +14 -10
  5. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_assign.py +17 -11
  6. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_const_folding.py +42 -7
  7. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_conv_shapetracker.py +4 -6
  8. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_dtype.py +38 -15
  9. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_dtype_alu.py +42 -9
  10. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_fusion_op.py +8 -9
  11. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_gc.py +24 -8
  12. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_graph.py +2 -3
  13. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_hcq.py +105 -77
  14. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_image_dtype.py +63 -8
  15. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_jit.py +101 -2
  16. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_kernel_cache.py +11 -9
  17. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_linearizer.py +133 -134
  18. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_linearizer_dumb.py +9 -7
  19. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_linearizer_failures.py +93 -113
  20. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_linearizer_overflows.py +9 -9
  21. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_metal.py +3 -5
  22. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_multitensor.py +166 -52
  23. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_net_speed.py +1 -1
  24. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_nn.py +40 -34
  25. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_ops.py +416 -73
  26. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_optim.py +1 -1
  27. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_pickle.py +59 -5
  28. tinygrad-0.10.1/test/test_profiler.py +163 -0
  29. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_randomness.py +2 -1
  30. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_renderer_failures.py +13 -6
  31. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_schedule.py +838 -173
  32. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_search.py +11 -10
  33. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_setitem.py +24 -9
  34. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_speed_v_torch.py +0 -1
  35. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_subbuffer.py +19 -1
  36. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_symbolic_ops.py +0 -2
  37. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_symbolic_shapetracker.py +12 -1
  38. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_tensor.py +133 -55
  39. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_tensor_data.py +14 -0
  40. tinygrad-0.10.0/test/test_lazybuffer.py → tinygrad-0.10.1/test/test_tensor_uop.py +32 -49
  41. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_tiny.py +42 -8
  42. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_transcendental.py +4 -4
  43. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_uop_graph.py +66 -64
  44. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_uops.py +225 -43
  45. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_uops_stats.py +32 -18
  46. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_winograd.py +11 -4
  47. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_zero_copy.py +1 -1
  48. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/codegen/kernel.py +114 -172
  49. tinygrad-0.10.1/tinygrad/codegen/linearize.py +225 -0
  50. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/codegen/lowerer.py +30 -35
  51. tinygrad-0.10.0/tinygrad/codegen/uopgraph.py → tinygrad-0.10.1/tinygrad/codegen/rewriter.py +69 -59
  52. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/codegen/transcendental.py +12 -13
  53. tinygrad-0.10.1/tinygrad/device.py +344 -0
  54. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/dtype.py +28 -26
  55. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/engine/jit.py +80 -63
  56. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/engine/memory.py +4 -5
  57. tinygrad-0.10.1/tinygrad/engine/multi.py +162 -0
  58. tinygrad-0.10.1/tinygrad/engine/realize.py +168 -0
  59. tinygrad-0.10.1/tinygrad/engine/schedule.py +486 -0
  60. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/engine/search.py +40 -44
  61. tinygrad-0.10.1/tinygrad/gradient.py +70 -0
  62. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/helpers.py +77 -58
  63. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/nn/__init__.py +30 -32
  64. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/nn/datasets.py +1 -2
  65. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/nn/optim.py +22 -26
  66. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/nn/state.py +89 -64
  67. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/ops.py +562 -446
  68. tinygrad-0.10.1/tinygrad/renderer/__init__.py +132 -0
  69. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/renderer/cstyle.py +70 -84
  70. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/renderer/llvmir.py +32 -20
  71. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/renderer/ptx.py +79 -99
  72. tinygrad-0.10.1/tinygrad/renderer/wgsl.py +87 -0
  73. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  74. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/runtime/autogen/comgr.py +2 -0
  75. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/runtime/autogen/kfd.py +4 -3
  76. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/runtime/autogen/kgsl.py +1 -1
  77. tinygrad-0.10.1/tinygrad/runtime/autogen/libpciaccess.py +2023 -0
  78. tinygrad-0.10.1/tinygrad/runtime/autogen/llvm.py +11379 -0
  79. tinygrad-0.10.1/tinygrad/runtime/autogen/vfio.py +891 -0
  80. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/runtime/graph/cuda.py +8 -9
  81. tinygrad-0.10.1/tinygrad/runtime/graph/hcq.py +205 -0
  82. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/runtime/graph/metal.py +19 -21
  83. tinygrad-0.10.1/tinygrad/runtime/ops_amd.py +632 -0
  84. tinygrad-0.10.1/tinygrad/runtime/ops_clang.py +22 -0
  85. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/runtime/ops_cloud.py +34 -34
  86. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/runtime/ops_cuda.py +30 -27
  87. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/runtime/ops_disk.py +62 -63
  88. tinygrad-0.10.1/tinygrad/runtime/ops_dsp.py +272 -0
  89. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/runtime/ops_gpu.py +30 -30
  90. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/runtime/ops_hip.py +29 -31
  91. tinygrad-0.10.1/tinygrad/runtime/ops_llvm.py +56 -0
  92. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/runtime/ops_metal.py +93 -73
  93. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/runtime/ops_npy.py +2 -2
  94. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/runtime/ops_nv.py +232 -270
  95. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/runtime/ops_python.py +51 -46
  96. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/runtime/ops_qcom.py +129 -157
  97. tinygrad-0.10.1/tinygrad/runtime/ops_webgpu.py +63 -0
  98. tinygrad-0.10.1/tinygrad/runtime/support/allocator.py +94 -0
  99. tinygrad-0.10.1/tinygrad/runtime/support/am/amdev.py +384 -0
  100. tinygrad-0.10.1/tinygrad/runtime/support/am/ip.py +463 -0
  101. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/runtime/support/compiler_cuda.py +4 -2
  102. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/runtime/support/elf.py +26 -4
  103. tinygrad-0.10.1/tinygrad/runtime/support/hcq.py +469 -0
  104. tinygrad-0.10.1/tinygrad/runtime/support/llvm.py +32 -0
  105. tinygrad-0.10.1/tinygrad/shape/__init__.py +0 -0
  106. tinygrad-0.10.1/tinygrad/shape/shapetracker.py +142 -0
  107. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/shape/view.py +103 -138
  108. tinygrad-0.10.1/tinygrad/spec.py +154 -0
  109. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/tensor.py +744 -496
  110. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad.egg-info/PKG-INFO +20 -9
  111. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad.egg-info/SOURCES.txt +15 -7
  112. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad.egg-info/requires.txt +6 -4
  113. tinygrad-0.10.0/test/test_profiler.py +0 -221
  114. tinygrad-0.10.0/test/test_viz.py +0 -93
  115. tinygrad-0.10.0/tinygrad/codegen/linearize.py +0 -95
  116. tinygrad-0.10.0/tinygrad/device.py +0 -221
  117. tinygrad-0.10.0/tinygrad/engine/lazy.py +0 -228
  118. tinygrad-0.10.0/tinygrad/engine/realize.py +0 -217
  119. tinygrad-0.10.0/tinygrad/engine/schedule.py +0 -419
  120. tinygrad-0.10.0/tinygrad/function.py +0 -212
  121. tinygrad-0.10.0/tinygrad/multi.py +0 -177
  122. tinygrad-0.10.0/tinygrad/renderer/__init__.py +0 -89
  123. tinygrad-0.10.0/tinygrad/runtime/graph/clang.py +0 -39
  124. tinygrad-0.10.0/tinygrad/runtime/graph/hcq.py +0 -200
  125. tinygrad-0.10.0/tinygrad/runtime/ops_amd.py +0 -471
  126. tinygrad-0.10.0/tinygrad/runtime/ops_clang.py +0 -35
  127. tinygrad-0.10.0/tinygrad/runtime/ops_dsp.py +0 -181
  128. tinygrad-0.10.0/tinygrad/runtime/ops_llvm.py +0 -51
  129. tinygrad-0.10.0/tinygrad/runtime/support/hcq.py +0 -539
  130. tinygrad-0.10.0/tinygrad/shape/shapetracker.py +0 -111
  131. {tinygrad-0.10.0 → tinygrad-0.10.1}/LICENSE +0 -0
  132. {tinygrad-0.10.0 → tinygrad-0.10.1}/setup.cfg +0 -0
  133. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_compile_failures.py +0 -0
  134. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_conv.py +0 -0
  135. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_copy_speed.py +0 -0
  136. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_device_speed.py +0 -0
  137. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_fuzz_shape_ops.py +0 -0
  138. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_masked_st.py +0 -0
  139. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_method_cache.py +0 -0
  140. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_ocl.py +0 -0
  141. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_rearrange_einops.py +0 -0
  142. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_sample.py +0 -0
  143. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_specific_conv.py +0 -0
  144. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_symbolic_jit.py +0 -0
  145. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_tensor_variable.py +0 -0
  146. {tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_to_numpy.py +0 -0
  147. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/__init__.py +0 -0
  148. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/codegen/__init__.py +0 -0
  149. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/engine/__init__.py +0 -0
  150. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/py.typed +0 -0
  151. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/runtime/__init__.py +0 -0
  152. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/runtime/autogen/adreno.py +0 -0
  153. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/runtime/autogen/cuda.py +0 -0
  154. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/runtime/autogen/hip.py +0 -0
  155. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/runtime/autogen/hsa.py +0 -0
  156. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/runtime/autogen/io_uring.py +0 -0
  157. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/runtime/autogen/libc.py +0 -0
  158. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/runtime/autogen/nv_gpu.py +0 -0
  159. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/runtime/autogen/nvrtc.py +0 -0
  160. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/runtime/autogen/opencl.py +0 -0
  161. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/runtime/autogen/qcom_dsp.py +0 -0
  162. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/runtime/graph/__init__.py +0 -0
  163. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/runtime/support/__init__.py +0 -0
  164. {tinygrad-0.10.0/tinygrad/shape → tinygrad-0.10.1/tinygrad/runtime/support/am}/__init__.py +0 -0
  165. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad/runtime/support/compiler_hip.py +0 -0
  166. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad.egg-info/dependency_links.txt +0 -0
  167. {tinygrad-0.10.0 → tinygrad-0.10.1}/tinygrad.egg-info/top_level.txt +0 -0
{tinygrad-0.10.0 → tinygrad-0.10.1}/PKG-INFO
@@ -1,6 +1,6 @@
- Metadata-Version: 2.1
+ Metadata-Version: 2.2
  Name: tinygrad
- Version: 0.10.0
+ Version: 0.10.1
  Summary: You like pytorch? You like micrograd? You love tinygrad! <3
  Author: George Hotz
  License: MIT
@@ -9,15 +9,13 @@ Classifier: License :: OSI Approved :: MIT License
  Requires-Python: >=3.10
  Description-Content-Type: text/markdown
  License-File: LICENSE
- Provides-Extra: llvm
- Requires-Dist: llvmlite; extra == "llvm"
  Provides-Extra: arm
  Requires-Dist: unicorn; extra == "arm"
  Provides-Extra: triton
  Requires-Dist: triton-nightly>=2.1.0.dev20231014192330; extra == "triton"
  Provides-Extra: linting
  Requires-Dist: pylint; extra == "linting"
- Requires-Dist: mypy==1.11.2; extra == "linting"
+ Requires-Dist: mypy==1.13.0; extra == "linting"
  Requires-Dist: typing-extensions; extra == "linting"
  Requires-Dist: pre-commit; extra == "linting"
  Requires-Dist: ruff; extra == "linting"
@@ -25,6 +23,7 @@ Requires-Dist: types-tqdm; extra == "linting"
  Provides-Extra: testing
  Requires-Dist: numpy; extra == "testing"
  Requires-Dist: torch; extra == "testing"
+ Requires-Dist: jax; extra == "testing"
  Requires-Dist: pillow; extra == "testing"
  Requires-Dist: pytest; extra == "testing"
  Requires-Dist: pytest-xdist; extra == "testing"
@@ -44,6 +43,9 @@ Requires-Dist: hypothesis; extra == "testing"
  Requires-Dist: nibabel; extra == "testing"
  Requires-Dist: bottle; extra == "testing"
  Requires-Dist: ggml-python; extra == "testing"
+ Requires-Dist: capstone; extra == "testing"
+ Provides-Extra: webgpu
+ Requires-Dist: wgpu; extra == "webgpu"
  Provides-Extra: docs
  Requires-Dist: mkdocs; extra == "docs"
  Requires-Dist: mkdocs-material; extra == "docs"
@@ -55,6 +57,14 @@ Requires-Dist: numpy; extra == "docs"
  Provides-Extra: testing-tf
  Requires-Dist: tensorflow==2.15.1; extra == "testing-tf"
  Requires-Dist: tensorflow_addons; extra == "testing-tf"
+ Dynamic: author
+ Dynamic: classifier
+ Dynamic: description
+ Dynamic: description-content-type
+ Dynamic: license
+ Dynamic: provides-extra
+ Dynamic: requires-python
+ Dynamic: summary

  <div align="center">

@@ -146,6 +156,7 @@ tinygrad already supports numerous accelerators, including:
  - [x] [AMD](tinygrad/runtime/ops_amd.py)
  - [x] [NV](tinygrad/runtime/ops_nv.py)
  - [x] [QCOM](tinygrad/runtime/ops_qcom.py)
+ - [x] [WEBGPU](tinygrad/runtime/ops_webgpu.py)

  And it is easy to add more! Your accelerator of choice only needs to support a total of ~25 low level ops.

@@ -183,8 +194,8 @@ y = Tensor([[2.0,0,-2.0]], requires_grad=True)
  z = y.matmul(x).sum()
  z.backward()

- print(x.grad.numpy()) # dz/dx
- print(y.grad.numpy()) # dz/dy
+ print(x.grad.tolist()) # dz/dx
+ print(y.grad.tolist()) # dz/dy
  ```

  The same thing but in PyTorch:
@@ -196,8 +207,8 @@ y = torch.tensor([[2.0,0,-2.0]], requires_grad=True)
  z = y.matmul(x).sum()
  z.backward()

- print(x.grad.numpy()) # dz/dx
- print(y.grad.numpy()) # dz/dy
+ print(x.grad.tolist()) # dz/dx
+ print(y.grad.tolist()) # dz/dy
  ```

  ## Contributing
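The PKG-INFO hunks above add a `webgpu` extra (Provides-Extra: webgpu / Requires-Dist: wgpu) and a WEBGPU entry in the accelerator list. A minimal sketch of exercising that backend, assuming the optional `wgpu` package is installed and WebGPU is actually usable on the machine; the explicit `device=` argument is just one way to pick a backend:

    from tinygrad import Tensor

    # assumes the optional wgpu package (the new 'webgpu' extra) is installed
    t = Tensor([1.0, 2.0, 3.0], device="WEBGPU")
    print((t + 1).tolist())  # runs the add on the WebGPU backend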
{tinygrad-0.10.0 → tinygrad-0.10.1}/README.md
@@ -88,6 +88,7 @@ tinygrad already supports numerous accelerators, including:
  - [x] [AMD](tinygrad/runtime/ops_amd.py)
  - [x] [NV](tinygrad/runtime/ops_nv.py)
  - [x] [QCOM](tinygrad/runtime/ops_qcom.py)
+ - [x] [WEBGPU](tinygrad/runtime/ops_webgpu.py)

  And it is easy to add more! Your accelerator of choice only needs to support a total of ~25 low level ops.

@@ -125,8 +126,8 @@ y = Tensor([[2.0,0,-2.0]], requires_grad=True)
  z = y.matmul(x).sum()
  z.backward()

- print(x.grad.numpy()) # dz/dx
- print(y.grad.numpy()) # dz/dy
+ print(x.grad.tolist()) # dz/dx
+ print(y.grad.tolist()) # dz/dy
  ```

  The same thing but in PyTorch:
@@ -138,8 +139,8 @@ y = torch.tensor([[2.0,0,-2.0]], requires_grad=True)
  z = y.matmul(x).sum()
  z.backward()

- print(x.grad.numpy()) # dz/dx
- print(y.grad.numpy()) # dz/dy
+ print(x.grad.tolist()) # dz/dx
+ print(y.grad.tolist()) # dz/dy
  ```

  ## Contributing
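For reference, the updated tinygrad autograd example from the README, assembled into one runnable snippet; the `x = Tensor.eye(3, requires_grad=True)` line is not visible in these hunks and is taken from the README's full example:

    from tinygrad import Tensor

    x = Tensor.eye(3, requires_grad=True)
    y = Tensor([[2.0, 0, -2.0]], requires_grad=True)
    z = y.matmul(x).sum()
    z.backward()

    print(x.grad.tolist())  # dz/dx
    print(y.grad.tolist())  # dz/dy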
{tinygrad-0.10.0 → tinygrad-0.10.1}/setup.py
@@ -8,14 +8,14 @@ with open(directory / 'README.md', encoding='utf-8') as f:
  long_description = f.read()

  setup(name='tinygrad',
- version='0.10.0',
+ version='0.10.1',
  description='You like pytorch? You like micrograd? You love tinygrad! <3',
  author='George Hotz',
  license='MIT',
  long_description=long_description,
  long_description_content_type='text/markdown',
  packages = ['tinygrad', 'tinygrad.runtime.autogen', 'tinygrad.codegen', 'tinygrad.nn', 'tinygrad.renderer', 'tinygrad.engine',
- 'tinygrad.runtime', 'tinygrad.runtime.support', 'tinygrad.runtime.graph', 'tinygrad.shape'],
+ 'tinygrad.runtime', 'tinygrad.runtime.support', 'tinygrad.runtime.support.am', 'tinygrad.runtime.graph', 'tinygrad.shape'],
  package_data = {'tinygrad': ['py.typed']},
  classifiers=[
  "Programming Language :: Python :: 3",
@@ -24,12 +24,11 @@ setup(name='tinygrad',
  install_requires=[],
  python_requires='>=3.10',
  extras_require={
- 'llvm': ["llvmlite"],
  'arm': ["unicorn"],
  'triton': ["triton-nightly>=2.1.0.dev20231014192330"],
  'linting': [
  "pylint",
- "mypy==1.11.2",
+ "mypy==1.13.0",
  "typing-extensions",
  "pre-commit",
  "ruff",
@@ -39,6 +38,7 @@ setup(name='tinygrad',
  'testing': [
  "numpy",
  "torch",
+ "jax",
  "pillow",
  "pytest",
  "pytest-xdist",
@@ -57,8 +57,10 @@ setup(name='tinygrad',
  "hypothesis",
  "nibabel",
  "bottle",
- "ggml-python"
+ "ggml-python",
+ "capstone"
  ],
+ 'webgpu': ["wgpu"],
  'docs': [
  "mkdocs",
  "mkdocs-material",
{tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_arange.py
@@ -1,11 +1,12 @@
  import unittest, contextlib
  import numpy as np
- from tinygrad import Tensor, GlobalCounters, dtypes, nn
+ from tinygrad import Tensor, GlobalCounters, dtypes, nn, Device
  from tinygrad.helpers import CI, Context, getenv
  from tinygrad.engine.realize import run_schedule
  from tinygrad.codegen.kernel import Opt, OptOps, Kernel, KernelOptError
  from tinygrad.engine.realize import CompiledRunner, ExecItem
  from tinygrad.engine.search import get_kernel_actions
+ from tinygrad.ops import Ops

  class TestArange(unittest.TestCase):
  def _get_flops(self, N, opts=None):
@@ -21,7 +22,7 @@ class TestArange(unittest.TestCase):
  #print(p.src)
  ExecItem(CompiledRunner(p), [tt.lazydata.buffer]).run()
  np.testing.assert_equal(tt.numpy(), np.arange(N))
- return p.op_estimate
+ return p.estimates.ops

  def test_complexity(self, opts=None, limit=None):
  # add 1 to avoid divide by 0. arange is 0 flops now!
@@ -40,7 +41,7 @@ class TestArange(unittest.TestCase):
  def test_complexity_w_upcast_and_unroll(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], limit=1)

  @unittest.skip("doesn't work yet")
- def test_complexity_w_local_and_padto(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16), Opt(op=OptOps.PADTO, axis=1, amt=32)])
+ def test_complexity_w_local_and_padto(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16), Opt(op=OptOps.PADTO, axis=1, arg=32)])

  def test_all_opts(self, opts=None, exclude=None):
  k = Kernel(Tensor.arange(256).schedule()[-1].ast)
@@ -58,13 +59,15 @@ class TestArange(unittest.TestCase):
  self.test_complexity(opts)
  def test_all_opts_w_local(self):
  with contextlib.suppress(KernelOptError):
- return self.test_all_opts([Opt(OptOps.LOCAL, 0, 16)], [Opt(op=OptOps.PADTO, axis=1, amt=32)])
+ return self.test_all_opts([Opt(OptOps.LOCAL, 0, 16)], [Opt(op=OptOps.PADTO, axis=1, arg=32)])
  def test_all_opts_w_upcast(self): return self.test_all_opts([Opt(OptOps.UPCAST, 0, 4)])
- def test_all_opts_w_unroll(self): return self.test_all_opts([Opt(OptOps.UNROLL, 0, 4)], [Opt(op=OptOps.GROUP, axis=0, amt=0)])
+ def test_all_opts_w_unroll(self): return self.test_all_opts([Opt(OptOps.UNROLL, 0, 4)], [Opt(op=OptOps.GROUP, axis=0, arg=0)])
  def test_all_opts_w_upcast_and_unroll(self):
- return self.test_all_opts([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], [Opt(op=OptOps.GROUP, axis=0, amt=0)])
+ return self.test_all_opts([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], [Opt(op=OptOps.GROUP, axis=0, arg=0)])

  class TestIndexing(unittest.TestCase):
+ # update: passing after CAST_BEFORE_VIEW=1 deletion
+ # @unittest.expectedFailure
  def test_arange_2_reduce(self):
  needle = Tensor.zeros(16384, dtype=dtypes.int).contiguous()
  needle[1337] = 1
@@ -86,7 +89,7 @@ class TestIndexing(unittest.TestCase):
  print("*** indexing ***")
  with Context(NOOPT=1, FUSE_ARANGE=1):
  GlobalCounters.reset()
- rng = Tensor.ones(4, 256, 16384, dtype=dtypes.int)._cumsum(axis=-1, _first_zero=True).reshape(4, 256, 16384, 1)
+ rng = Tensor.ones(4, 256, 16384, dtype=dtypes.int)._cumalu(axis=-1, op=Ops.ADD, _include_initial=True).reshape(4, 256, 16384, 1)
  idxs = idxs.reshape(4,1,1,1).expand(4, 256, 16384, 1)
  reshape_dataset = dataset.T.reshape(1, 256, 16384, 1).expand(4, 256, 16384, 1)
  full = (rng==idxs).where(reshape_dataset, Tensor.zeros(4, 256, 16384, 1))
@@ -138,7 +141,7 @@ class TestIndexing(unittest.TestCase):
  np.testing.assert_equal(X.numpy(), 0)

  @unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
- def test_index_mnist(self, noopt=1, op_limit=512*784*5):
+ def test_index_mnist(self, noopt=1, op_limit=512*784*13):
  from tinygrad.nn.datasets import mnist
  X_train, Y_train, _, _ = mnist()
  with Context(NOOPT=noopt, FUSE_ARANGE=1, SPLIT_REDUCEOP=0):
@@ -152,12 +155,13 @@ class TestIndexing(unittest.TestCase):
  @unittest.skip("not ready")
  def test_index_mnist_opt(self): self.test_index_mnist(0)

- @unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
+ @unittest.skipIf(getenv("PTX") or Device.DEFAULT == "WEBGPU", "broken on ptx and WebGPU for some reason")
  def test_llama_embedding(self, noopt=1, op_limit=65536):
  # llama3 is 128256
  vocab_size, embed_size = (10, 3) if CI else (32000, 4096)
  emb = nn.Embedding(vocab_size, embed_size)
- emb_w = emb.weight.numpy()
+ # TODO: why is a new realize needed here
+ emb_w = emb.weight.realize().numpy()
  x = Tensor([1,2,3,4])
  with Context(NOOPT=noopt, FUSE_ARANGE=1):
  GlobalCounters.reset()
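The test_arange.py hunks above track two renames in 0.10.1: the `Opt` keyword `amt` becomes `arg`, and a program's op count moves from `p.op_estimate` to `p.estimates.ops`. A minimal sketch of the new spellings, reusing the kernel construction from `test_all_opts`; `apply_opt` and `to_program` come from the surrounding test helpers rather than these hunks:

    from tinygrad import Tensor
    from tinygrad.codegen.kernel import Kernel, Opt, OptOps

    # build a kernel the same way test_all_opts does
    k = Kernel(Tensor.arange(256).schedule()[-1].ast)
    k.apply_opt(Opt(op=OptOps.UPCAST, axis=0, arg=4))  # 0.10.0 spelled this keyword `amt`
    p = k.to_program()
    print(p.estimates.ops)  # 0.10.0 exposed this as p.op_estimate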
{tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_assign.py
@@ -2,7 +2,8 @@
  import unittest
  import numpy as np
  from tinygrad import dtypes, Tensor, TinyJit, GlobalCounters, Variable
- from tinygrad.engine.schedule import create_schedule
+ from tinygrad.device import is_dtype_supported
+ from tinygrad.helpers import temp

  N = 200 # has to be bigger than the cache to fail

@@ -168,16 +169,6 @@ class TestAssign(unittest.TestCase):
  a += 1
  np.testing.assert_allclose(a.numpy(), 3)

- # NOTE: this is similar to the resnet failure
- #@unittest.expectedFailure
- def test_double_assign_alt(self):
- a = Tensor.ones(4).contiguous().realize()
- b = Tensor([1, 2, 3, 4]).realize().lazydata
- a1 = a.lazydata.assign(b)
- a2 = a.lazydata.assign(b)
- sched = create_schedule([a1, a2])
- self.assertEqual(len(sched), 1)
-
  def test_crossover_assign(self):
  a = Tensor.full((4,), 2).contiguous().realize()
  b = Tensor.full((4,), 3).contiguous().realize()
@@ -293,6 +284,7 @@ class TestAssign(unittest.TestCase):
  #assert ba1 == ba2 and ba1 != bb1
  np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))

+ @unittest.skip("multi output not supported anymore")
  def test_simple_assignment_multioutput(self):
  a = Tensor.randn(32, 32).realize()
  b = Tensor.full((32, ), 1.).contiguous().realize()
@@ -331,6 +323,7 @@ class TestAssign(unittest.TestCase):
  b.assign(r + b.permute(1, 0))
  b.realize()

+ @unittest.skip("multi output not supported anymore")
  def test_permuted_reduceop_multioutput_dual_use(self):
  a = Tensor.randn(32, 32, 32).realize()
  b = Tensor.full((32, 32), 1.).contiguous().realize()
@@ -343,6 +336,7 @@ class TestAssign(unittest.TestCase):
  c.assign(r + b_perm)
  Tensor.realize(b, c)

+ @unittest.skip("multi output not supported anymore")
  def test_permuted_reduceop_multioutput_dual_use_possible(self):
  a = Tensor.randn(32, 32, 32, dtype=dtypes.int).realize()
  b = Tensor.arange(32 * 32).reshape(32, 32).realize()
@@ -376,6 +370,14 @@ class TestAssign(unittest.TestCase):

  # TODO: is there a way to sneak in a permute such that it returns the wrong answer?

+ @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
+ def test_setitem_half(self):
+ a = Tensor.full((8,), 1.0, dtype=dtypes.half).contiguous().realize()
+ b = Tensor.full((4,), 2.0, dtype=dtypes.half).contiguous().realize()
+ assign = a[:4].assign(b)
+ assign.realize()
+ np.testing.assert_allclose(a.numpy(), [2., 2., 2., 2., 1., 1., 1., 1.])
+
  @unittest.skip("don't use output buffer, and mismatch dtype no longer supported")
  def test_cast_assignment(self):
  a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
@@ -387,5 +389,9 @@ class TestAssign(unittest.TestCase):
  assert oba1 is None and oba2 is None
  np.testing.assert_allclose(a.numpy(), np.arange(N*N,dtype=np.int32).reshape((N,N)))

+ def test_disk_assignment(self):
+ a = Tensor.empty(5, device=f"disk:{temp('disk_assignment')}").assign(Tensor.ones(5)).numpy()
+ np.testing.assert_equal(a, np.ones(5))
+
  if __name__ == "__main__":
  unittest.main()
{tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_const_folding.py
@@ -1,16 +1,15 @@
  import unittest, math
  from tinygrad import Tensor, Device, dtypes
  from tinygrad.ops import Ops
- from tinygrad.engine.schedule import create_schedule
  from tinygrad.helpers import CI
  import numpy as np
  from tinygrad.device import is_dtype_supported

  def _check_ast_count(desired_count:int, t:Tensor):
  # NOTE: this has side effect because everything can be scheduled only once
- schedule = create_schedule(t.lazydata.lbs)
+ schedule = t.schedule()
  asts = [s for s in schedule if s.ast.op is Ops.SINK]
- assert len(asts) == desired_count
+ assert len(asts) == desired_count, f"{len(asts)} != {desired_count}"

  class TestUnaryOpsConstFolding(unittest.TestCase):
  def test_all_consts_ops(self):
@@ -95,15 +94,17 @@ class TestBinaryOpsConstFolding(unittest.TestCase):
  _check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** Tensor.ones(4))
  def test_literal_one_pow(self):
  _check_ast_count(0, 1 ** Tensor([1.0, 2, 3, 4]))
+ # TODO: pow simplification
  def test_tensor_one_pow(self):
- _check_ast_count(0, Tensor.ones(4) ** Tensor([1.0, 2, 3, 4]))
+ _check_ast_count(1, Tensor.ones(4) ** Tensor([1.0, 2, 3, 4]))

  # folds advance indexing into basic indexing
  class TestIndexingConstFolding(unittest.TestCase):
  def test_scalar_index(self):
  t = Tensor.arange(16).float().reshape(1,1,4,4).realize()
  _check_ast_count(0, t[:,:,Tensor(1),:])
- _check_ast_count(0, t[:,:,Tensor(1)+2,:])
+ # NOTE: this is no longer supported because the 1+2 isn't folding early.
+ #_check_ast_count(0, t[:,:,Tensor(1)+2,:])
  _check_ast_count(0, t[:,:,Tensor(1),Tensor(0)])

  @unittest.expectedFailure
@@ -130,11 +131,12 @@ class TestMovedConstFolding(unittest.TestCase):

  def test_cast_padded(self):
  # NOTE: this is folded due to CAST_BEFORE_VIEW
+ # update: CAST_BEFORE_VIEW=1 is no longer supported
  if is_dtype_supported(dtypes.int16):
- _check_ast_count(0, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16))
+ _check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16))
  np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16).numpy(), [0, 1, 1, 1, 1, 0])
  if is_dtype_supported(dtypes.uint16):
- _check_ast_count(0, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
+ _check_ast_count(1, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
  np.testing.assert_equal(Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16).numpy(), [0, 65535, 65535, 65535, 65535, 0])
  # not folded
  if is_dtype_supported(dtypes.int64):
@@ -158,6 +160,37 @@ class TestReduceOpsConstFolding(unittest.TestCase):
  _check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).exp().sum())
  np.testing.assert_allclose(Tensor.ones(4).pad(((1, 1),)).exp().sum().numpy(), 4 * math.e + 2)

+ def test_bool_zero_max(self):
+ _check_ast_count(0, Tensor.full((1, 2), True).shrink(((0, 1), (0, 0))).max((1, 0)))
+ np.testing.assert_equal(Tensor.full((1, 2), True).shrink(((0, 1), (0, 0))).max((1, 0)).numpy(), False)
+
+ def test_zero_size_ops(self):
+ for reduceop in [lambda x:x.prod(), lambda x:x.sum()]: # lambda x:x.max() NOTE: numpy gives "reduction operation maximum which has no identity"
+ _check_ast_count(0, reduceop(Tensor.empty(1, 0)))
+ np.testing.assert_equal(reduceop(Tensor.empty(shape:=(1, 0))).numpy(), reduceop(np.empty(shape)))
+
+ def test_zero_size_ops_view(self):
+ for reduceop in [lambda x:x.prod(), lambda x:x.sum()]:
+ _check_ast_count(0, reduceop(Tensor.empty(1, 0, 4).permute((1, 2, 0)).contiguous()))
+ np.testing.assert_equal(reduceop(Tensor.empty(shape:=(1, 0))).numpy(), reduceop(np.empty((shape))))
+
+ def test_zero_size_ops_realized(self):
+ for reduceop in [lambda x:x.prod(), lambda x:x.sum()]:
+ _check_ast_count(0, reduceop((Tensor.randn(0, 1)+1).realize()))
+ np.testing.assert_equal(reduceop((Tensor.randn(shape:=(0, 1))+1).realize()).numpy(), reduceop(np.empty(shape)))
+
+ def test_zero_size_realize_folded(self):
+ # non contiguous folded output doesn't realize
+ _check_ast_count(0, Tensor.empty(1, 0).sum())
+ # contiguous folded const can still schedule
+ a = Tensor.empty(1, 0).sum().contiguous()
+ _check_ast_count(2, a+2)
+ self.assertIsNotNone(a.lazydata.base.realized)
+ np.testing.assert_equal((Tensor.empty(1, 0).sum().contiguous()+2).numpy(), 2)
+ # otherwise we just fuse it
+ _check_ast_count(1, (Tensor.empty(1, 0).sum()+2).contiguous())
+ np.testing.assert_equal((Tensor.empty(1, 0).sum()+2).numpy(), 2)
+
  def test_const_prod(self):
  _check_ast_count(0, Tensor.full((2, 3), fill_value=2).prod())
  np.testing.assert_equal(Tensor.full((2, 3), fill_value=2).prod().numpy(), 2**(2*3))
@@ -206,6 +239,8 @@ class TestMultiConstFolding(unittest.TestCase):
  _check_ast_count(0, t ** 1)
  _check_ast_count(0, 1 ** t)

+ # failing because multi calls .contiguous() on every single sharded uop
+ @unittest.expectedFailure
  def test_multi_const_folding_tensor(self):
  ds = tuple(f"{Device.DEFAULT}:{i}" for i in range(4))
  t = Tensor.arange(16).float().realize().to(ds)
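The `create_schedule` import removed above is replaced by calling `Tensor.schedule()` directly, as the updated `_check_ast_count` helper shows. A minimal sketch of counting scheduled kernels in the new style, using the same const-folded expression as `test_const_prod`:

    from tinygrad import Tensor
    from tinygrad.ops import Ops

    t = Tensor.full((2, 3), fill_value=2).prod()
    # count only real compute kernels, like the updated _check_ast_count helper
    kernels = [s for s in t.schedule() if s.ast.op is Ops.SINK]
    print(len(kernels))  # const folding leaves nothing to run, so this is 0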
{tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_conv_shapetracker.py
@@ -3,7 +3,6 @@ import unittest
  from tinygrad.ops import Ops
  from tinygrad.tensor import Tensor
  from tinygrad.nn import Conv2d
- from tinygrad.engine.schedule import create_schedule
  from tinygrad.shape.shapetracker import ShapeTracker, View
  from tinygrad.helpers import prod
  from test.unit.test_shapetracker import shapetracker_getitem
@@ -11,13 +10,12 @@ from test.unit.test_shapetracker import shapetracker_getitem
  class TestConvShapetracker(unittest.TestCase):
  def test_conv_3x3_one_view(self):
  conv = Conv2d(16, 32, (3, 3))
-
  # first run to init the weights, they are scheduled.
- create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata])
+ conv(Tensor.empty(1, 16, 10, 10)).schedule()
  # run it again to get the kernels
- sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata]) if si.ast.op is Ops.SINK]
+ sched = [si for si in conv(Tensor.empty(1, 16, 10, 10)).schedule() if si.ast.op is Ops.SINK]
  assert len(sched) == 1, f"conv should only have one kernel, getting {len(sched)}"
- for st in [x.st_arg for x in sched[0].ast.parents if x.op is Ops.LOAD]:
+ for st in [x.st_arg for x in sched[0].ast.toposort if x.op is Ops.LOAD]:
  assert len(st.views) == 1

  def test_conv_2x2_backward_one_view(self):
@@ -26,7 +24,7 @@ class TestConvShapetracker(unittest.TestCase):
  conv(X).mean().backward()
  si = X.grad.schedule()[-1]
  print(si)
- ldb = [x for x in si.ast.parents if x.op is Ops.LOAD][0]
+ ldb = [x for x in si.ast.toposort if x.op is Ops.LOAD][0]
  st: ShapeTracker = ldb.st_arg.simplify()
  # NOTE: st.real_size() is broken
  print(si.inputs[0].size)
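The `.parents` to `.toposort` rename above is how 0.10.1 walks a scheduled AST's UOps. A minimal sketch of collecting the LOAD ops from a schedule item in the new style; the toy tensor expression is only an illustration, not part of this test:

    from tinygrad import Tensor
    from tinygrad.ops import Ops

    si = (Tensor.empty(4, 4) + 1).schedule()[-1]
    # toposort yields the AST's UOps in topological order (a property in 0.10.1)
    loads = [u for u in si.ast.toposort if u.op is Ops.LOAD]
    print(len(loads))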
{tinygrad-0.10.0 → tinygrad-0.10.1}/test/test_dtype.py
@@ -4,10 +4,10 @@ import torch
  from typing import Any, List
  from tinygrad.device import is_dtype_supported
  from tinygrad.helpers import getenv, DEBUG, CI
- from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype, truncate_fp16
+ from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype, truncate_fp16, to_dtype
  from tinygrad import Device, Tensor, dtypes
  from tinygrad.tensor import _to_np_dtype
- from hypothesis import given, settings, strategies as strat
+ from hypothesis import assume, given, settings, strategies as strat
  from test.helpers import rand_for_dtype
  import pytest
  pytestmark = pytest.mark.filterwarnings("ignore")
@@ -35,11 +35,11 @@ def _test_to_np(a:Tensor, np_dtype, target):
  except AssertionError as e:
  raise AssertionError(f"\ntensor {a.numpy()} does not match target {target} with np_dtype {np_dtype}") from e

- def _assert_eq(tensor:Tensor, target_dtype:DType, target):
+ def _assert_eq(tensor:Tensor, target_dtype:DType, target, tol_target_dtype:float=1e-7):
  if DEBUG >= 2: print(tensor.numpy())
  try:
  assert tensor.dtype == target_dtype
- np.testing.assert_allclose(tensor.numpy(), target, rtol={dtypes.float16:1e-3, dtypes.bfloat16:1e-2}.get(target_dtype, 1e-7))
+ np.testing.assert_allclose(tensor.numpy(), target, rtol={dtypes.float16:1e-3, dtypes.bfloat16:1e-2}.get(target_dtype, tol_target_dtype))
  except AssertionError as e:
  raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e

@@ -52,13 +52,12 @@ def _test_cast(a:Tensor, target_dtype:DType):
  if target_dtype == dtypes.half and Device.DEFAULT == "PYTHON":
  # TODO: struct.pack cannot pack value > 65504 (max of half) into e format
  a = (a > 65504).where(65504, a)
- if CI and Device.DEFAULT == "CLANG" and (target_dtype, a.dtype) in [(dtypes.double, dtypes.half), (dtypes.half, dtypes.double)]:
- # TODO: cast between double and half are broken https://github.com/tinygrad/tinygrad/issues/4084
- return

  _test_op(lambda: a.cast(target_dtype), target_dtype, list(a.numpy().astype(_to_np_dtype(target_dtype))))
  def _test_bitcast(a:Tensor, target_dtype:DType, target=None):
  if target_dtype == dtypes.bfloat16: raise unittest.SkipTest("no test for bf16 bitcast yet")
+ if getenv("PTX") and a.dtype == dtypes.int8 and target_dtype.itemsize != a.dtype.itemsize:
+ raise unittest.SkipTest("shape changing bitcast of int8 broken on PTX")
  _test_op(lambda: a.bitcast(target_dtype), target_dtype, target or a.numpy().view(_to_np_dtype(target_dtype)).tolist())

  class TestDType(unittest.TestCase):
@@ -99,7 +98,6 @@ class TestDType(unittest.TestCase):
  get_available_cast_dtypes(self.DTYPE)
  ))
  def test_bitcast(self):
- if Device.DEFAULT == "WEBGL": raise unittest.SkipTest("no bitcast in WebGL GLSL")
  if self.DTYPE == dtypes.bool: raise unittest.SkipTest("no bools in bitcast")
  list(map(
  lambda dtype:
@@ -109,6 +107,9 @@ class TestDType(unittest.TestCase):

  def test_dtypes_fields(self):
  fields = dtypes.fields()
+ self.assertIn("float", fields)
+ self.assertIn("float32", fields)
+ self.assertEqual(len(fields), 24)
  self.assertTrue(all(isinstance(value, DType) for value in fields.values()))
  self.assertTrue(all(issubclass(_to_np_dtype(value), np.generic) for value in fields.values() if _to_np_dtype(value) is not None))

@@ -117,7 +118,9 @@ class TestDType(unittest.TestCase):
  data = [1., 2., 0., 0.5, -1.5, 5.25]
  for dt in dtypes:
  arr = np.asarray(data).astype(dt)
- tin = Tensor(arr).numpy()
+ tensor = Tensor(arr)
+ if not is_dtype_supported(tensor.dtype): continue
+ tin = tensor.numpy()
  tor = torch.as_tensor(arr).detach().numpy()
  assert dt == tin.dtype == tor.dtype, f"dtype mismatch: expected={dt} | tinygrad={tin.dtype} | torch={tor.dtype}"
  np.testing.assert_allclose(tin, tor, atol=1e-6, rtol=1e-3)
@@ -244,6 +247,11 @@ class TestInt8DType(TestDType):
  def test_int8_to_uint16_negative(self):
  _test_op(lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint16), dtypes.uint16, [2**16-1, 2**16-2, 2**16-3, 2**16-4])

+ @unittest.skipIf(getenv("PTX"), "broken in ptx")
+ def test_bitcast_alt(self):
+ a = Tensor([72, -90, 27, 40, -53, 70, 96, 51], dtype=dtypes.int8).bitcast(dtypes.short)
+ self.assertListEqual(a.tolist(), [-22968, 10267, 18123, 13152])
+
  class TestUint8DType(TestDType):
  DTYPE = dtypes.uint8
  @unittest.skipIf(getenv("CUDA",0)==1 or getenv("PTX", 0)==1, "cuda saturation works differently")
@@ -254,7 +262,9 @@ class TestUint8DType(TestDType):
  class TestBitCast(unittest.TestCase):
  @given(strat.sampled_from(dtype_ints + dtype_floats), strat.sampled_from(dtype_ints + dtype_floats))
  def test_shape_change_bitcast(self, dt1, dt2):
- if dt2 == dtypes.bfloat16: raise unittest.SkipTest("no test for bf16 bitcast yet")
+ # NOTE: this has to be assume to prevent hypothesis from skipping all samples
+ assume(dt2 != dtypes.bfloat16 and dt1 != dtypes.bfloat16) # no test for bf16 bitcast yet
+ assume(not (getenv("PTX") and dt1 == dtypes.int8)) # TODO: bitcasting int8 fails in PTX
  data = rand_for_dtype(dt1, 32).reshape(2, 2, 8)
  _test_op(lambda: Tensor(data, dtype=dt1).bitcast(dt2), dt2, data.view(_to_np_dtype(dt2)).tolist())

@@ -355,7 +365,7 @@ class TestEqStrDType(unittest.TestCase):
  def test_strs(self):
  if PtrDType is None: raise unittest.SkipTest("no PtrDType support")
  self.assertEqual(str(dtypes.imagef((1,2,4))), "dtypes.imagef((1, 2, 4))")
- self.assertEqual(str(dtypes.float32.ptr()), "dtypes.float.ptr()")
+ self.assertEqual(str(dtypes.float32.ptr(16)), "dtypes.float.ptr(16)")

  class TestHelpers(unittest.TestCase):
  signed_ints = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64)
@@ -538,7 +548,7 @@ class TestTypeSpec(unittest.TestCase):
  _assert_eq(Tensor.arange(5, dtype=dtypes.int64), dtypes.int64, np.arange(5))
  if is_dtype_supported(dtypes.float16):
  _assert_eq(Tensor.arange(5, dtype=dtypes.float16), dtypes.float16, np.arange(5))
- _assert_eq(Tensor.arange(3, 9, 0.7), dtypes.default_float, np.arange(3, 9, 0.7))
+ _assert_eq(Tensor.arange(3, 9, 0.7), dtypes.default_float, np.arange(3, 9, 0.7), 1e-6 if Device.DEFAULT == "WEBGPU" else 1e-7)
  _assert_eq(Tensor.arange(3, 8.5, 3), dtypes.default_float, np.arange(3, 8.5, 3))
  # stop-start and step have different signs
  _assert_eq(Tensor.arange(3, 5, -2), dtypes.default_int, np.arange(3, 5, -2))
@@ -642,8 +652,7 @@ class TestAutoCastType(unittest.TestCase):
  def test_broadcast_scalar(self, dt):
  assert (Tensor.ones(4, 4, dtype=dt) + 2.3).dtype == (dt if dtypes.is_float(dt) else dtypes.default_float)
  assert (Tensor.ones(4, 4, dtype=dt) + 2).dtype == (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int)
- if Device.DEFAULT != "WEBGPU" and dt != dtypes.bool:
- assert (Tensor.ones(4, 4, dtype=dt) + True).dtype == dt
+ assert (Tensor.ones(4, 4, dtype=dt) + True).dtype == dt

  def test_sum(self):
  assert (Tensor([0, 1], dtype=dtypes.bool)).sum().dtype == dtypes.int32
@@ -772,7 +781,8 @@ class TestAutoCastType(unittest.TestCase):
  if DEBUG >= 2:
  print(f"testing {default_dtype=}, {dtype=}")
  a = Tensor([1, 2, 3], dtype=dtype, requires_grad=True)
- b = (a * 5).sum()
+ # NOTE: this is broken without default_dtype because of CAST_BEFORE_VIEW
+ b = (a * 5).sum(acc_dtype=default_dtype)
  b.backward() # if there is dtype mismatch, lazy should assert
  assert a.grad.dtype == a.dtype
  np.testing.assert_allclose(a.grad.numpy(), [5, 5, 5])
@@ -851,5 +861,18 @@ class TestDtypeUsage(unittest.TestCase):
  t = Tensor([[1, 2], [3, 4]], dtype=d)
  (t*t).max().item()

+ class TestToDtype(unittest.TestCase):
+ def test_dtype_to_dtype(self):
+ dtype = dtypes.int32
+ res = to_dtype(dtype)
+ self.assertIsInstance(res, DType)
+ self.assertEqual(res, dtypes.int32)
+
+ def test_str_to_dtype(self):
+ dtype = "int32"
+ res = to_dtype(dtype)
+ self.assertIsInstance(res, DType)
+ self.assertEqual(res, dtypes.int32)
+
  if __name__ == '__main__':
  unittest.main()